
Commit 1774a1a

Authored by tzj-fxz, chengyupku, and johnnynunez
[Feature] Add 1D TMA support (#761)
* [Feature] Add 1D TMA support
  - Check the contiguity conditions for a 1D TMA copy
  - Add a new interface and parameter order for the `tma_load` and `tma_store` calls
  - Add a 1D `tma_store` interface to the sm90 template
  - Add an elementwise kernel as a 1D TMA example
* [Lint]
* [BugFix] Add conditions for the 1D TMA copy on non-swizzled shared tensors
* [Lint]
* [BugFix] 1D TMA load
* [README] Update the GDN README for clarity and add acknowledgements (#758)
  - Improved the formatting and clarity of the GDN kernel implementation description.
  - Updated the requirements section to list dependencies in a clearer format.
  - Added an acknowledgements section to credit the developers and the Xiaomi LLM-Core Team for their contributions.
* cutlass v4.2.0, supporting CUDA 13 (#760)
* [Lint]
* [Lint]
* [MXFP4] Add a test for the bf16 & mxfp4 GEMM
* [BugFix]
* [Lint]

Co-authored-by: Yu Cheng <54519279+chengyupku@users.noreply.github.com>
Co-authored-by: Johnny <johnnync13@gmail.com>
1 parent e05a20a commit 1774a1a


10 files changed: +291 −17 lines


examples/dequantize_gemm/test_example_dequantize_gemm.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -2,6 +2,7 @@

 import example_dequant_gemv_fp16xint4
 import example_dequant_gemm_fp4_hopper
+import example_dequant_gemm_bf16_mxfp4_hopper


 @tilelang.testing.requires_cuda
@@ -15,5 +16,11 @@ def test_example_dequant_gemm_fp4_hopper():
     example_dequant_gemm_fp4_hopper.main()


+@tilelang.testing.requires_cuda
+@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
+def test_example_dequant_gemm_bf16_mxfp4_hopper():
+    example_dequant_gemm_bf16_mxfp4_hopper.main()
+
+
 if __name__ == "__main__":
     tilelang.testing.main()
```
example_elementwise_add_tma_1d.py (new file)

Lines changed: 53 additions & 0 deletions

```python
import argparse
import tilelang
import tilelang.language as T
import torch


def ref_program(x, y):
    return x + y


@tilelang.jit(out_idx=[-1])
def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):

    @T.prim_func
    def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor(
            (M, N), out_dtype)):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_N), in_dtype)
            B_shared = T.alloc_shared((block_M, block_N), in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), out_dtype)
            C_shared = T.alloc_shared((block_M, block_N), out_dtype)

            # Whole-tile global <-> shared copies; these are the transfers the
            # new 1D TMA lowering can pick up when the regions are contiguous
            T.copy(A[by * block_M, bx * block_N], A_shared)
            T.copy(B[by * block_M, bx * block_N], B_shared)
            for (local_y, local_x) in T.Parallel(block_M, block_N):
                C_local[local_y, local_x] = A_shared[local_y, local_x] + B_shared[local_y, local_x]
            T.copy(C_local, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return elem_add


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--m", type=int, default=128)
    parser.add_argument("--n", type=int, default=128)
    args, _ = parser.parse_known_args()
    M, N = args.m, args.n

    a = torch.randn(M, N, dtype=torch.float32, device="cuda")
    b = torch.randn(M, N, dtype=torch.float32, device="cuda")

    # Default config
    config = {"block_M": 128, "block_N": 128, "threads": 128}
    kernel = elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32")

    out = kernel(a, b)
    torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2)
    print("All passed!")


if __name__ == "__main__":
    main()
```
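A note on the default configuration: with block_M = block_N = 128 over the default 128x128 problem size, each `T.copy` above moves one fully contiguous region, which appears to be exactly what the 1D lowering path in `src/op/copy.cc` (below) requires — scanned from the innermost dimension outward, the copied region must cover whole dimensions, with at most one partial dimension before the remaining extents collapse to 1. The example can be run directly, e.g. `python example_elementwise_add_tma_1d.py --m 256 --n 256` (argument names taken from the argparse setup above).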
Lines changed: 5 additions & 0 deletions
```diff
@@ -1,10 +1,15 @@
 import tilelang.testing
 import example_elementwise_add
+import example_elementwise_add_tma_1d


 def test_example_elementwise_add():
     example_elementwise_add.main()


+def test_example_elementwise_add_tma_1d():
+    example_elementwise_add_tma_1d.main()
+
+
 if __name__ == "__main__":
     tilelang.testing.main()
```

examples/gdn/example_wy_fast_bwd_split.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -18,7 +18,6 @@

 import torch
 import torch.nn.functional as F
-from utils import assert_similar

 torch.random.manual_seed(0)
 torch.set_printoptions(profile="full")
@@ -504,6 +503,7 @@ def run_test(
     dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum(
         dim=-1)

+    from utils import assert_similar
     assert_similar(dk_ref, dk_tilelang, eps=1e-5, name="dk", raise_assert=False)
     assert_similar(dv_ref, dv_tilelang, eps=1e-5, name="dv", raise_assert=False)
     assert_similar(dbeta_ref, dbeta_tilelang, eps=1e-5, name="dbeta", raise_assert=False)
```

src/op/copy.cc

Lines changed: 117 additions & 2 deletions
```diff
@@ -772,19 +772,133 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
     stride *= s;
   }

+  Array<PrimExpr> global_indices;
+  for (auto r : global_range) {
+    global_indices.push_back(r->min);
+  }
+  std::vector<PrimExpr> global_strides;
+  PrimExpr global_stride = 1;
+  for (size_t i = 0; i < global_tensor->shape.size(); i++) {
+    auto s = global_tensor->shape[global_tensor->shape.size() - i - 1];
+    global_strides.insert(global_strides.begin(), global_stride);
+    global_stride *= s;
+  }
+
   ICHECK(strides.size() == indices.size())
       << "strides.size() != indices.size()" << strides.size() << " "
       << indices.size();
   PrimExpr offset = 0;
   for (size_t i = 0; i < indices.size(); i++) {
     offset += indices[i] * strides[i];
   }
+  PrimExpr global_offset = 0;
+  for (size_t i = 0; i < global_indices.size(); i++) {
+    global_offset += global_indices[i] * global_strides[i];
+  }
+  auto shared_tensor_before_remap = shared_tensor;
   Layout shared_layout;
   if (T.layout_map.count(shared_tensor)) {
     shared_layout = T.layout_map[shared_tensor];
     shared_tensor = T.buffer_remap[shared_tensor];
   }

+  // Use a 1D TMA copy when both the global and the shared regions are
+  // contiguous
+  {
+    // Check whether shared_tensor->name appears in T.buffer_var_gemm
+    // (Array<Var>) to avoid using the 1D TMA copy on swizzled shared layouts
+    bool shared_is_contiguous = true;
+    for (const auto &v : T.buffer_var_gemm) {
+      if (v->name_hint == shared_tensor->name) {
+        shared_is_contiguous = false;
+        break;
+      }
+    }
+    bool shared_not_full_dim_encounter = false;
+    for (ssize_t i = shared_range.size() - 1; i >= 0; --i) {
+      if (!shared_not_full_dim_encounter) {
+        if (!analyzer->CanProve(shared_range[i]->extent ==
+                                    shared_tensor_before_remap->shape[i] &&
+                                shared_range[i]->min == 0)) {
+          shared_not_full_dim_encounter = true;
+        }
+      } else {
+        if (!analyzer->CanProve(shared_range[i]->extent == 1)) {
+          shared_is_contiguous = false;
+          break;
+        }
+      }
+    }
+    // Currently we check the empty stride of the global tensor
+    bool global_is_contiguous = !global_tensor->strides.empty();
+    bool global_not_full_dim_encounter = false;
+    for (ssize_t i = global_range.size() - 1; i >= 0; --i) {
+      if (!global_not_full_dim_encounter) {
+        if (!analyzer->CanProve(global_range[i]->extent ==
+                                    global_tensor->shape[i] &&
+                                global_range[i]->min == 0)) {
+          global_not_full_dim_encounter = true;
+        }
+      } else {
+        if (!analyzer->CanProve(global_range[i]->extent == 1)) {
+          global_is_contiguous = false;
+          break;
+        }
+      }
+    }
+    // Ensure the element counts match and there is no out-of-bounds access
+    PrimExpr shared_elements = 1;
+    for (size_t i = 0; i < shared_range.size(); i++) {
+      shared_elements *= shared_range[i]->extent;
+    }
+    PrimExpr global_elements = 1;
+    for (size_t i = 0; i < global_range.size(); i++) {
+      global_elements *= global_range[i]->extent;
+    }
+    bool element_match =
+        analyzer->CanProveEqual(shared_elements, global_elements);
+    bool no_oob = true;
+    for (size_t i = 0; i < shared_range.size(); i++) {
+      if (!analyzer->CanProve(shared_range[i]->min + shared_range[i]->extent <=
+                              shared_tensor_before_remap->shape[i])) {
+        no_oob = false;
+        break;
+      }
+    }
+    for (size_t i = 0; i < global_range.size(); i++) {
+      if (!analyzer->CanProve(global_range[i]->min + global_range[i]->extent <=
+                              global_tensor->shape[i])) {
+        no_oob = false;
+        break;
+      }
+    }
+    // Use the 1D TMA copy only for loads
+    if (shared_is_contiguous && global_is_contiguous && element_match &&
+        no_oob && is_load) {
+      PrimExpr elements = analyzer->Simplify(shared_elements);
+      PrimExpr shared_addr = shared_tensor_before_remap.access_ptr(
+          is_load ? 2 : 1, DataType::Handle(), 1, offset, elements);
+      PrimExpr global_addr = global_tensor.access_ptr(
+          is_load ? 1 : 2, DataType::Handle(), 1, global_offset, elements);
+      Stmt tma_copy;
+      if (is_load) {
+        // the zero is a placeholder for the mbarrier id
+        tma_copy =
+            Evaluate(Call(DataType::Handle(), tma_load(),
+                          {shared_addr, global_addr, 0,
+                           elements * shared_tensor_before_remap->dtype.bytes(),
+                           this->eviction_policy}));
+      } else {
+        tma_copy =
+            Evaluate(Call(DataType::Handle(), tma_store(),
+                          {global_addr, shared_addr,
+                           elements * shared_tensor_before_remap->dtype.bytes(),
+                           this->eviction_policy}));
+      }
+      tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy);
+      return tma_copy;
+    }
+  }
+
   TMADesc desc;
   // Verify copy rank
   desc.rank = global_tensor->shape.size();
@@ -1221,10 +1335,11 @@ Array<PrimExpr> TMAIm2ColDesc::EncodeCallArgs() const {

 // Register the Copy operation with TVM's TIR system
 // This makes the copy operation available for use in TVM programs
-// - Takes 4 inputs: src_buffer, dst_buffer, coalesced_width, disable_tma
+// - Takes 5 inputs: src_buffer, dst_buffer, coalesced_width, disable_tma,
+//   eviction_policy
 // - Marked as opaque since it has side effects (memory writes)
 TIR_REGISTER_TL_OP(Copy, copy)
-    .set_num_inputs(4)
+    .set_num_inputs(5)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));
```
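The contiguity test above walks dimensions from innermost to outermost. As a minimal sketch of that rule on plain integers (a hypothetical standalone helper for illustration, not part of the commit):

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// A slice is contiguous if, scanning from the innermost dimension outward,
// every dimension covers its full extent until the first partial dimension
// is met, and every dimension outside that one has extent 1.
// Each range entry is {min, extent}; shape[i] is the full size of dim i.
bool IsContiguousSlice(const std::vector<std::pair<int, int>> &range,
                       const std::vector<int> &shape) {
  bool partial_dim_seen = false;
  for (std::ptrdiff_t i = static_cast<std::ptrdiff_t>(range.size()) - 1;
       i >= 0; --i) {
    if (!partial_dim_seen) {
      if (!(range[i].first == 0 && range[i].second == shape[i]))
        partial_dim_seen = true;
    } else if (range[i].second != 1) {
      return false;  // a partial dim may only have size-1 dims outside it
    }
  }
  return true;
}
```

For example, a block of whole rows of a 128x128 tensor passes (the inner dimension is full, the outer one is the single partial dimension), while a 128x64 tile fails: the inner dimension is partial and the outer extent of 128 is not 1.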

src/op/op.h

Lines changed: 1 addition & 0 deletions
```diff
@@ -49,6 +49,7 @@ struct LowerArgs {
   AddWorkspaceCallback AddWorkspace;
   LayoutMap layout_map;
   Map<Buffer, Buffer> buffer_remap;
+  Array<Var> buffer_var_gemm;
 };

 struct LayoutInferArgs {
```
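For context: `buffer_var_gemm` appears to carry the buffer variables of shared tensors that feed GEMM operations; `Copy::LowerBulkCopy` (above) consults it so the 1D TMA fast path skips those operands, whose shared layouts may be swizzled.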

src/tl_templates/cuda/copy_sm90.h

Lines changed: 10 additions & 0 deletions
```diff
@@ -171,6 +171,16 @@ tma_load_im2col(const CUtensorMap &descriptor, BarrierType &smem_mbar,
       : "memory");
 }

+template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
+TL_DEVICE void tma_store(void *gmem_ptr, void *smem_ptr, uint32_t size) {
+  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
+  asm volatile("cp.async.bulk.global.shared::cta.bulk_group"
+               ".L2::cache_hint [%0], [%1], %2, %3;"
+               :
+               : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(size), "l"(cache_hint)
+               :);
+}
+
 template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
 TL_DEVICE void tma_store(const CUtensorMap &descriptor,
                          void const *const smem_ptr, int32_t const &crd0) {
```
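This diff adds only the 1D `tma_store`; the matching 1D `tma_load` overload that the lowering in `src/op/copy.cc` targets is not shown in this extraction. A hedged sketch of what such an overload could look like, mirroring the PTX style of the surrounding header (the signature is assumed from the call site, not confirmed by this diff):

```cuda
// Sketch only: a 1D bulk TMA load whose completion is signaled on an
// mbarrier, built from the cp.async.bulk PTX forms and the header's own
// smem_ptr_to_uint / CacheHintSm90 helpers.
template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL,
          typename BarrierType>
TL_DEVICE void tma_load(void *smem_ptr, void *gmem_ptr, BarrierType &smem_mbar,
                        uint32_t size) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
  // Bulk copy global -> shared; transaction bytes tracked by the mbarrier
  asm volatile("cp.async.bulk.shared::cluster.global.mbarrier::complete_tx"
               "::bytes.L2::cache_hint [%0], [%1], %2, [%3], %4;"
               :
               : "r"(smem_int_ptr), "l"(gmem_ptr), "r"(size),
                 "r"(smem_int_mbar), "l"(cache_hint)
               : "memory");
}
```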

src/transform/inject_tma_barrier.cc

Lines changed: 26 additions & 7 deletions
```diff
@@ -62,10 +62,17 @@ class TmaTraitsCollector : public StmtExprVisitor {
 private:
   void VisitExpr_(const CallNode *call) final {
     if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
-      Call access_ptr = Downcast<Call>(call->args[2]);
-      ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr()));
-      int type_bytes = access_ptr->args[0]->dtype.bytes();
-      bulk_copy_bytes += access_ptr->args[3] * loop_extents * type_bytes;
+      auto arg0 = call->args[0].as<Call>();
+      if (call->op.same_as(tma_load()) && arg0 &&
+          !arg0.value()->op.same_as(create_tma_descriptor())) {
+        // 1D TMA load has the tvm_access_ptr of the shared tensor in args[0]
+        bulk_copy_bytes = call->args[3] * loop_extents;
+      } else {
+        Call access_ptr = Downcast<Call>(call->args[2]);
+        ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr()));
+        int type_bytes = access_ptr->args[0]->dtype.bytes();
+        bulk_copy_bytes += access_ptr->args[3] * loop_extents * type_bytes;
+      }
     }
     StmtExprVisitor::VisitExpr_(call);
   }
@@ -155,10 +162,15 @@ class TmaExpectTxRewriter : public IRMutatorWithAnalyzer {

   PrimExpr VisitExpr_(const CallNode *op) {
     if (op->op.same_as(tma_load())) {
+      auto arg0 = op->args[0].as<Call>();
+      bool is_1d_tma_load =
+          arg0 && !arg0.value()->op.same_as(create_tma_descriptor()) &&
+          op->op.same_as(tma_load());
       visited_tma_load_ = true;
       Array<PrimExpr> new_args = op->args;
-      new_args.Set(1, Call(DataType::Handle(), get_mbarrier(),
-                           {IntImm(DataType::Int(32), 0)}));
+      new_args.Set(is_1d_tma_load ? 2 : 1,
+                   Call(DataType::Handle(), get_mbarrier(),
+                        {IntImm(DataType::Int(32), 0)}));
       return Call(op->dtype, op->op, new_args);
     }
     return IRMutatorWithAnalyzer::VisitExpr_(op);
@@ -443,7 +455,14 @@ class TmaBarrierRewriter : public IRMutatorWithAnalyzer {
           << "tma_load must be in the tma_op_to_barrier_id_";
       auto barrier_id = tma_op_to_barrier_id_[GetRef<Call>(op)];
       auto new_args = op->args;
-      new_args.Set(1, barrier_id);
+      auto arg0 = op->args[0].as<Call>();
+      auto is_1d_tma_load =
+          arg0 && !arg0.value()->op.same_as(create_tma_descriptor());
+      if (is_1d_tma_load) {
+        new_args.Set(2, barrier_id);
+      } else {
+        new_args.Set(1, barrier_id);
+      }
       return Call(op->dtype, op->op, new_args);
     } else if (op->op.same_as(mbarrier_expect_tx())) {
       ICHECK(tma_op_to_barrier_id_.count(GetRef<Call>(op)))
```
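All three passes distinguish the 1D form the same way. A distillation for reference (the helper name `Is1DTmaLoad` is illustrative, not from the commit; the argument layout is reconstructed from this commit's diffs):

```cpp
// Detection used above: a tma_load call is the 1D form iff its first argument
// is a call that is not create_tma_descriptor(), i.e. it is already the
// tvm_access_ptr of the shared tensor rather than a multi-dim TMA descriptor.
bool Is1DTmaLoad(const CallNode *op) {
  auto arg0 = op->args[0].as<Call>();
  return arg0 && !arg0.value()->op.same_as(create_tma_descriptor());
}

// Resulting argument layouts (as emitted by Copy::LowerBulkCopy):
//   descriptor-based tma_load: args[1] = mbarrier, args[2] = shared access_ptr
//   1D tma_load:               args[0] = shared access_ptr,
//                              args[1] = global access_ptr,
//                              args[2] = mbarrier,
//                              args[3] = copy size in bytes,
//                              args[4] = eviction policy
// which is why the rewriters patch index 2 instead of index 1 for the 1D form.
```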
