Skip to content

Commit 10477be

Browse files
Observer007 authored and joker-eph committed
Add TMA Store operation to the NVVM dialect
Reviewed By: guraypp Differential Revision: https://reviews.llvm.org/D159535
1 parent 774116b commit 10477be

File tree

3 files changed

+79
-15
lines changed

3 files changed

+79
-15
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1490,6 +1490,28 @@ def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp : NVVM_Op<"cp.async.bulk.tenso
14901490
let hasVerifier = 1;
14911491
}
14921492

1493+
// TMA store: bulk tensor copy described by a TMA descriptor, a source pointer
// and 1 to 5 i32 coordinates. The op is lowered through the PTX builder
// interface using the string returned by getPtx().
def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp :
  NVVM_Op<"cp.async.bulk.tensor.global.shared.cta",
          [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
  Arguments<(ins LLVM_i64ptr_any:$tmaDescriptor,
                 LLVM_i64ptr_shared:$srcMem,
                 Variadic<I32>:$coordinates)> {
  let assemblyFormat = "$tmaDescriptor `,` $srcMem `,` `box` `[`$coordinates `]` attr-dict `:` type(operands)";
  let extraClassDefinition = [{
    // Builds "cp.async.bulk.tensor.<rank>d.global.shared::cta.bulk_group"
    // followed by the operand list. NOTE(review): the %N placeholders appear
    // to follow operand order (%0 descriptor, %1 source, %2.. coordinates);
    // confirm against BasicPtxBuilderOpInterface.
    std::string $cppClass::getPtx() {
      const int rank = getCoordinates().size();
      std::string ptx = "cp.async.bulk.tensor." + std::to_string(rank) +
                        "d.global.shared::cta.bulk_group";
      // Only ranks 1 through 5 get an operand list; any other rank leaves the
      // mnemonic without operands.
      if (rank >= 1 && rank <= 5) {
        ptx += " [%0, {%2";
        // Remaining coordinates occupy placeholders %3 .. %(rank + 1).
        for (int placeholder = 3; placeholder < rank + 2; ++placeholder)
          ptx += ", %" + std::to_string(placeholder);
        ptx += "} ], [%1];";
      }
      return ptx;
    }
  }];
  let hasVerifier = 1;
}
1514+
14931515
//===----------------------------------------------------------------------===//
14941516
// NVVM Wgmma Ops
14951517
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,12 @@ LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
8181
return success();
8282
}
8383

84+
/// Verifies the coordinate count of the TMA store op.
///
/// getPtx() only appends an operand list for 1 to 5 coordinates, so any other
/// count would produce a PTX string without operands and is rejected here.
///
/// Returns failure (with a diagnostic) when the op has no coordinates or more
/// than 5 of them, success otherwise.
LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
  const auto numCoordinates = getCoordinates().size();
  // An empty coordinate list would lower to a "0d" instruction with no
  // operands.
  if (numCoordinates == 0)
    return emitError("At least 1 coordinate is required.");
  if (numCoordinates > 5)
    return emitError("Maximum 5 coordinates and dimension is supported.");
  return success();
}
89+
8490
LogicalResult CpAsyncOp::verify() {
8591
if (getModifier() != LoadCacheModifierKind::CG &&
8692
getModifier() != LoadCacheModifierKind::CA)

mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir

Lines changed: 51 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,16 @@
44
// and the generic `convert-to-llvm` pass.
55
// RUN: mlir-opt --convert-to-llvm --split-input-file %s | FileCheck %s
66

7-
// CHECK-LABEL : @init_mbarrier_arrive_expect_tx
7+
// todo: remove extra space between `CHECK/CHECK-LABEL` and `:`
8+
9+
// CHECK-LABEL: @init_mbarrier_arrive_expect_tx
810
llvm.func @init_mbarrier_arrive_expect_tx(%barrier : !llvm.ptr<3>, %txcount : i32) {
9-
//CHECK : llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r"
11+
//CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r"
1012
nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount : !llvm.ptr<3>, i32
1113
llvm.return
1214
}
1315

14-
// CHECK-LABEL : @init_mbarrier_arrive_expect_tx_generic
16+
// CHECK-LABEL: @init_mbarrier_arrive_expect_tx_generic
1517
llvm.func @init_mbarrier_arrive_expect_tx_generic(%barrier : !llvm.ptr, %txcount : i32) {
1618
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.b64 _, [$0], $1;", "l,r"
1719
nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr, i32
@@ -32,7 +34,7 @@ llvm.func @init_mbarrier_try_wait(%barrier : !llvm.ptr, %ticks : i32, %phase : i
3234
llvm.return
3335
}
3436

35-
// CHECK-LABEL : @async_cp
37+
// CHECK-LABEL: @async_cp
3638
func.func @async_cp(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>) {
3739
// CHECK : nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16, cache = ca : !llvm.ptr<3>, !llvm.ptr<1>
3840
nvvm.cp.async.shared.global %dst, %src, 16, cache = ca : !llvm.ptr<3>, !llvm.ptr<1>
@@ -41,7 +43,7 @@ func.func @async_cp(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>) {
4143
return
4244
}
4345

44-
// CHECK-LABEL : @async_cp_zfill
46+
// CHECK-LABEL: @async_cp_zfill
4547
func.func @async_cp_zfill(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>, %cpSize: i32) {
4648
// CHECK : llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32) -> !llvm.void
4749
nvvm.cp.async.shared.global %dst, %src, 16, cache = cg, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32
@@ -50,41 +52,75 @@ func.func @async_cp_zfill(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>, %cpSize: i32)
5052
return
5153
}
5254

53-
// CHECK-LABEL : @tma_load_1d
55+
// CHECK-LABEL: @tma_load_1d
5456
func.func @tma_load_1d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32) {
55-
// CHECK : llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3}], [$2];", "l,r,r,r"
57+
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3} ], [$2];", "r,l,r,r"
5658
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32
5759
return
5860
}
5961

60-
// CHECK-LABEL : @tma_load_2d
62+
// CHECK-LABEL: @tma_load_2d
6163
func.func @tma_load_2d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32) {
62-
// CHECK : llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3, $4}], [$2];", "l,r,r,r,r"
64+
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3, $4} ], [$2];", "r,l,r,r,r"
6365
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32
6466
return
6567
}
6668

67-
// CHECK-LABEL : @tma_load_3d
69+
// CHECK-LABEL: @tma_load_3d
6870
func.func @tma_load_3d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32) {
69-
// CHECK : llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3, $4, $5}], [$2];", "l,r,r,r,r,r"
71+
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3, $4, $5} ], [$2];", "r,l,r,r,r,r"
7072
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32, i32
7173
return
7274
}
7375

74-
// CHECK-LABEL : @tma_load_4d
76+
// CHECK-LABEL: @tma_load_4d
7577
func.func @tma_load_4d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32) {
76-
// CHECK : llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3, $4, $5, $6}], [$2];", "l,r,r,r,r,r,r"
78+
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3, $4, $5, $6} ], [$2];", "r,l,r,r,r,r,r"
7779
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2,%crd3] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32
7880
return
7981
}
8082

81-
// CHECK-LABEL : @tma_load_5d
83+
// CHECK-LABEL: @tma_load_5d
8284
func.func @tma_load_5d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %crd4: i32) {
83-
// CHECK : llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3, $4, $5, $6, $7}], [$2];", "l,r,r,r,r,r,r,r"
85+
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3, $4, $5, $6, $7} ], [$2];", "r,l,r,r,r,r,r,r"
8486
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd4] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32, i32
8587
return
8688
}
8789

90+
// CHECK-LABEL: @tma_store_1d
91+
func.func @tma_store_1d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32) {
92+
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [$0, {$2} ], [$1];", "l,r,r"
93+
nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0] : !llvm.ptr, !llvm.ptr<3>, i32
94+
return
95+
}
96+
97+
// CHECK-LABEL: @tma_store_2d
98+
func.func @tma_store_2d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32) {
99+
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [$0, {$2, $3} ], [$1];", "l,r,r,r"
100+
nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1] : !llvm.ptr, !llvm.ptr<3>, i32, i32
101+
return
102+
}
103+
104+
// CHECK-LABEL: @tma_store_3d
105+
func.func @tma_store_3d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32) {
106+
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.3d.global.shared::cta.bulk_group [$0, {$2, $3, $4} ], [$1];", "l,r,r,r,r"
107+
nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2] : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32
108+
return
109+
}
110+
111+
// CHECK-LABEL: @tma_store_4d
112+
func.func @tma_store_4d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32) {
113+
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.4d.global.shared::cta.bulk_group [$0, {$2, $3, $4, $5} ], [$1];", "l,r,r,r,r,r"
114+
nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2,%crd3] : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32
115+
return
116+
}
117+
118+
// CHECK-LABEL: @tma_store_5d
119+
func.func @tma_store_5d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %crd4: i32) {
120+
// CHECK-NEXT: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.5d.global.shared::cta.bulk_group [$0, {$2, $3, $4, $5, $6} ], [$1];", "l,r,r,r,r,r,r"
121+
nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2,%crd3,%crd4] : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32, i32
122+
return
123+
}
88124

89125
// CHECK-LABEL : @wgmma_execute
90126
func.func @wgmma_execute() {

0 commit comments

Comments
 (0)