diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index cefdd7cc4033a..9cda7862ccb0f 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -1186,6 +1186,35 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">, let hasVerifier = 1; } +def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">, + Arguments<(ins LLVM_i8Ptr_shared:$ptr, + Variadic:$sources, + MMALayoutAttr:$layout)> { + let summary = "cooperative matrix store"; + let description = [{ + Collectively store one or more matrices across all threads in a warp to the + location indicated by the address operand $ptr in shared memory. + [For more information, see PTX ISA] + (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix) + }]; + + let assemblyFormat = "$ptr `,` $sources attr-dict `:` type(operands)"; + let extraClassDefinition = [{ + std::string $cppClass::getPtx() { + int d = getSources().size(); + std::string ptx = "stmatrix.sync.aligned"; + ptx += ".x" + std::to_string(d); + if (getLayout() == NVVM::MMALayout::col) + ptx += ".trans"; + if(d == 1) ptx += ".m8n8.shared.b16 [%0], {%1}"; + if(d == 2) ptx += ".m8n8.shared.b16 [%0], {%1, %2}"; + if(d == 4) ptx += ".m8n8.shared.b16 [%0], {%1, %2, %3, %4};"; + return ptx; + } + }]; + let hasVerifier = 1; +} + def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">, Results<(outs AnyType:$res)>, Arguments<(ins LLVM_AnyPointer: $ptr, I32Attr:$num, MMALayoutAttr:$layout)> { diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 92df023c797b1..3736978505707 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -717,6 +717,19 @@ LogicalResult NVVM::LdMatrixOp::verify() { return success(); } +LogicalResult NVVM::StMatrixOp::verify() { + unsigned addressSpace = + llvm::cast(getPtr().getType()).getAddressSpace(); + if (addressSpace != NVVM::kSharedMemorySpace) + return emitOpError("expected source pointer in memory space 3"); + + int numMatrix = getSources().size(); + if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4) + return emitOpError("expected num attribute to be 1, 2 or 4"); + + return success(); +} + FailureOr getAllowedSizeK(NVVM::WGMMATypes typeA) { if (typeA == NVVM::WGMMATypes::tf32) return 8; diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir index 0d0ac9637438a..3bb0ab90775ed 100644 --- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir +++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir @@ -507,6 +507,30 @@ func.func @elect_one_leader_sync() { // ----- +// CHECK-LABEL: @stmatrix( +// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !llvm.ptr<3>, +// CHECK-SAME: %[[arg1:[a-zA-Z0-9_]+]]: i32, +// CHECK-SAME: %[[arg2:[a-zA-Z0-9_]+]]: i32, +// CHECK-SAME: %[[arg3:[a-zA-Z0-9_]+]]: i32, +// CHECK-SAME: %[[arg4:[a-zA-Z0-9_]+]]: i32) +llvm.func @stmatrix(%arg0 : !llvm.ptr<3>, %m1 : i32, %m2 : i32, %m3 : i32, %m4 : i32) { +// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.m8n8.shared.b16 [$0], {$1}", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> () +// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.m8n8.shared.b16 [$0], {$1, $2}", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> () +// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> () +// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [$0], {$1}", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> () +// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [$0], {$1, $2}", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> () +// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> () + nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout} : !llvm.ptr<3>, i32 + nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout} : !llvm.ptr<3>, i32, i32 + nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout} : !llvm.ptr<3>, i32, i32, i32, i32 + nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout} : !llvm.ptr<3>, i32 + nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout} : !llvm.ptr<3>, i32, i32 + nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout} : !llvm.ptr<3>, i32, i32, i32, i32 + llvm.return +} + +// ----- + // CHECK-LABEL: @init_mbarrier_arrive_expect_tx llvm.func @init_mbarrier_arrive_expect_tx(%desc : !llvm.ptr, %pred : i1) { //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "prefetch.tensormap [$0];", "l"