-
Notifications
You must be signed in to change notification settings - Fork 12.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][NVGPU] Introduce nvgpu.wargroup.mma.store
Op for Hopper GPUs
#65441
Conversation
The cursed typo "wargroup" is on :). |
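@llvm/pr-subscribers-mlir-nvgpu Changes: [MLIR][NVGPU] Introduce `nvgpu.warpgroup.mma.store` Op for Hopper GPUs. This work introduces a new operation called `warpgroup.mma.store` to the NVGPU dialect of MLIR. An example of fragmentation is given here: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d The `warpgroup.mma.store` op takes one or more fragmented result matrices, calculates the indices per thread in the warp group, and stores the data into the given memref.
Here's an example usage of the `nvgpu.warpgroup.mma.store` operation: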
Depends on #65440 Full diff: https://github.com/llvm/llvm-project/pull/65441.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index 90381648dac6acc..4e80c33aec6043d 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -721,4 +721,24 @@ def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> {
let hasVerifier = 1;
}
+// Stores one or more fragmented warpgroup accumulators ($matrixD) into the
+// destination memref ($dstMemref). The destination is marked with a MemWrite
+// effect, and the op has a C++ verifier (hasVerifier = 1).
+def NVGPU_WarpgroupMmaStoreOp : NVGPU_Op<"warpgroup.mma.store"> {
+  let description = [{
+    The `nvgpu.warpgroup.mma.store` op stores the fragmented result(s) held
+    in $matrixD into the given memref.
+
+    [See the details of register fragment layout for accumulator matrix D]
+    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d)
+
+    Note that the op must be executed by a warp group.
+  }];
+
+  let arguments = (ins Variadic<NVGPU_WarpgroupAccumulator>:$matrixD,
+                       Arg<AnyMemRef, "", [MemWrite]>:$dstMemref);
+
+  let assemblyFormat = [{
+    `[` $matrixD `]` `,` $dstMemref attr-dict `:` type($matrixD) `to` type($dstMemref)
+  }];
+  let hasVerifier = 1;
+}
+
#endif // NVGPU
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index f74aa05c0c4c4ff..006ecbef2546e3e 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -11,6 +11,7 @@
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
@@ -409,8 +410,8 @@ struct ConvertNVGPUToNVVMPass
using Base::Base;
void getDependentDialects(DialectRegistry ®istry) const override {
- registry
- .insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect>();
+ registry.insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect,
+ arith::ArithDialect>();
}
void runOnOperation() override {
@@ -451,6 +452,7 @@ struct ConvertNVGPUToNVVMPass
populateNVGPUToNVVMConversionPatterns(converter, patterns);
LLVMConversionTarget target(getContext());
target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
+ target.addLegalDialect<::mlir::arith::ArithDialect>();
target.addLegalDialect<::mlir::memref::MemRefDialect>();
target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
@@ -1299,6 +1301,82 @@ struct NVGPUWarpgroupMmaOpLowering
}
};
+/// Lowers `nvgpu.warpgroup.mma.store`: for every converted accumulator struct,
+/// each thread extracts its elements and writes them to the destination memref
+/// with `memref.store`. The original op is erased afterwards.
+struct NVGPUWarpgroupMmaStoreOpLowering
+    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> {
+  using ConvertOpToLLVMPattern<
+      nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;
+
+  /// Consecutive struct elements written to two neighbouring columns.
+  static constexpr int kAdjacentElements = 2;
+  /// Row groups (8 rows apart) that every thread writes.
+  static constexpr int kStackedRowGroups = 2;
+  /// Struct elements consumed by one 8-column tile.
+  static constexpr int kElementsPerColumnTile =
+      kAdjacentElements * kStackedRowGroups;
+
+  /// Emits the stores for a single accumulator struct `wgmmaResult`.
+  ///
+  /// With `lane = tid.x % kWarpSize` and `warp = tid.x / kWarpSize`, struct
+  /// element `2 * i + 4 * j` is stored at
+  ///   row    = lane / 4 + 16 * warp + offset + 8 * i
+  ///   column = 2 * (lane % 4) + 8 * j
+  /// and element `2 * i + 4 * j + 1` at the next column, for `i` in {0, 1} and
+  /// `j` in [0, structSize / 4).
+  ///
+  /// \param wgmmaResult LLVM struct value holding this thread's fragment.
+  /// \param op          The op being lowered; supplies location and memref.
+  /// \param adaptor     Unused here; kept so existing call sites still compile.
+  /// \param rewriter    Rewriter used to create the new operations.
+  /// \param offset      Constant added to the row index.
+  void storeFragmentedMatrix(Value wgmmaResult, nvgpu::WarpgroupMmaStoreOp op,
+                             OpAdaptor adaptor,
+                             ConversionPatternRewriter &rewriter,
+                             int offset) const {
+    Location loc = op->getLoc();
+    Type i32 = rewriter.getI32Type();
+
+    // Derive the number of 8-column tiles from the struct itself instead of
+    // assuming a 64-element fragment; a 64-element struct still yields 16.
+    auto structType = wgmmaResult.getType().cast<LLVM::LLVMStructType>();
+    const int numColumnTiles =
+        static_cast<int>(structType.getBody().size()) / kElementsPerColumnTile;
+
+    auto makeConst = [&](int32_t index) -> Value {
+      return rewriter.create<LLVM::ConstantOp>(
+          loc, i32, rewriter.getI32IntegerAttr(index));
+    };
+    Value c4 = makeConst(4);
+    Value c32 = makeConst(kWarpSize);
+    Value c8 = makeConst(8);
+    Value c2 = makeConst(2);
+    Value c1 = makeConst(1);
+    Value c16 = makeConst(16);
+
+    auto makeMul = [&](Value lhs, Value rhs) -> Value {
+      return rewriter.create<LLVM::MulOp>(loc, lhs.getType(), lhs, rhs);
+    };
+    auto makeAdd = [&](Value lhs, Value rhs) -> Value {
+      return rewriter.create<LLVM::AddOp>(loc, lhs.getType(), lhs, rhs);
+    };
+
+    // Decompose the thread id into warp / lane coordinates.
+    Value tidx = rewriter.create<NVVM::ThreadIdXOp>(loc, i32);
+    Value laneId = rewriter.create<LLVM::URemOp>(loc, i32, tidx, c32);
+    Value warpId = rewriter.create<LLVM::UDivOp>(loc, i32, tidx, c32);
+    Value lane4Id = rewriter.create<LLVM::UDivOp>(loc, i32, laneId, c4);
+    Value lane4modId = rewriter.create<LLVM::URemOp>(loc, i32, laneId, c4);
+
+    // Stores struct elements `i` and `i + 1` at (x, y) and (x, y + 1).
+    auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
+                                   TypedValue<::mlir::MemRefType> memref) {
+      Type it = rewriter.getIndexType();
+      Value idx = rewriter.create<arith::IndexCastOp>(loc, it, x);
+      Value idy0 = rewriter.create<arith::IndexCastOp>(loc, it, y);
+      Value idy1 = rewriter.create<arith::IndexCastOp>(loc, it, makeAdd(y, c1));
+      Value d0 = rewriter.create<LLVM::ExtractValueOp>(loc, wgmmaResult, i);
+      Value d1 = rewriter.create<LLVM::ExtractValueOp>(loc, wgmmaResult, i + 1);
+      rewriter.create<memref::StoreOp>(loc, d0, memref, ValueRange{idx, idy0});
+      rewriter.create<memref::StoreOp>(loc, d1, memref, ValueRange{idx, idy1});
+    };
+
+    Value tj = makeMul(lane4modId, c2);
+    Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
+    if (offset)
+      ti = makeAdd(ti, makeConst(offset));
+    for (int i = 0; i < kStackedRowGroups; ++i) {
+      Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
+      for (int j = 0; j < numColumnTiles; ++j) {
+        Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
+        int sIndex = i * kAdjacentElements + j * kElementsPerColumnTile;
+        makeExtractAndStore(sIndex, wgmmaResult, idx, idy, op.getDstMemref());
+      }
+    }
+  }
+
+  /// Replaces the op with per-thread stores for each accumulator, in operand
+  /// order. Returns failure without emitting IR when a converted accumulator
+  /// is not a non-empty LLVM struct whose size is a multiple of 4.
+  LogicalResult
+  matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Validate everything up front so that no partial lowering is emitted.
+    for (Value result : adaptor.getMatrixD()) {
+      auto stype = result.getType().dyn_cast<LLVM::LLVMStructType>();
+      if (!stype || stype.getBody().empty() ||
+          stype.getBody().size() % kElementsPerColumnTile != 0)
+        return rewriter.notifyMatchFailure(
+            op, "expected accumulators converted to non-empty LLVM structs "
+                "whose element count is a multiple of 4");
+    }
+
+    int offset = 0;
+    for (Value result : adaptor.getMatrixD()) {
+      auto stype = result.getType().cast<LLVM::LLVMStructType>();
+      storeFragmentedMatrix(result, op, adaptor, rewriter, offset);
+      // NOTE(review): the row offset advances by the struct element count,
+      // which equals the 64 rows written per accumulator only for 64-element
+      // structs -- confirm the intended stride for other fragment sizes.
+      offset += stype.getBody().size();
+    }
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
} // namespace
void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
@@ -1315,6 +1393,7 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx
NVGPUGenerateGmmaDescriptorLowering, // nvgpu.wgmma.generate.descriptor
NVGPUWarpgroupMmaOpLowering, // nvgpu.warpgroup.mma
+ NVGPUWarpgroupMmaStoreOpLowering, // nvgpu.warpgroup.mma.store
MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
NVGPUMmaSparseSyncLowering>(converter);
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index d96ed69982870b4..1486bba5d3e57f6 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -529,6 +530,37 @@ LogicalResult WarpgroupMmaOp::verify() {
return success();
}
+/// Verifies `nvgpu.warpgroup.mma.store`:
+///  - at least one accumulator operand is present,
+///  - every accumulator's fragmented type is the same non-empty LLVM struct,
+///  - the struct's first element is f32 and all its elements share one type.
+LogicalResult WarpgroupMmaStoreOp::verify() {
+  // `matrixD` is variadic, so it may be empty; `front()` below requires at
+  // least one operand.
+  if (getMatrixD().empty())
+    return emitOpError() << "expects at least one fragmented result to store";
+
+  Type stype = getMatrixD()
+                   .front()
+                   .getType()
+                   .cast<WarpgroupAccumulatorType>()
+                   .getFragmented();
+
+  for (auto result : getMatrixD()) {
+    auto resultStype = result.getType()
+                           .cast<WarpgroupAccumulatorType>()
+                           .getFragmented()
+                           .dyn_cast<LLVM::LLVMStructType>();
+    if (!resultStype)
+      return emitOpError() << "result is " << result.getType()
+                           << " but must keep type of llvm struct";
+    if (stype != resultStype)
+      return emitOpError() << "all results must be the same type";
+
+    // An empty struct has no first element to inspect.
+    if (resultStype.getBody().empty())
+      return emitOpError() << "fragmented result struct must not be empty";
+
+    // todo improve this limitation
+    if (!resultStype.getBody().front().isF32()) {
+      return emitOpError() << "supports only f32 results for the time being";
+    }
+  }
+
+  // The loop above has established that `stype` is an LLVM struct.
+  if (!llvm::all_equal(stype.cast<LLVM::LLVMStructType>().getBody())) {
+    return emitOpError() << "all element types must be equal";
+  }
+
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd dialect, type, and op definitions
//===----------------------------------------------------------------------===//
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks almost good to me.
I'd like to see a few more comments, and I believe some cmake/blaze changes are missing.
✅ With the latest revision this PR passed the C/C++ code formatter. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the singleton implementation for the warpsize value is broken and anyway overkill.
I believe we still miss a change in a cmake file.
Other than that, just a couple of nits, but LGTM.
This work introduces a new operation called `warpgroup.mma.store` to the NVGPU dialect of MLIR. The purpose of this operation is to facilitate storing the fragmented results of WGMMA to the given memref. An example of fragmentation is given here: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d The `warpgroup.mma.store` op does the following: 1) Takes one or more fragmented result matrices. 2) Calculates the indices per thread in the warp group and stores the data into the given memref. Here's an example usage of the `nvgpu.warpgroup.mma.store` operation: ``` // Performs matmul, results are fragmented and in registers %res1, %res2 = nvgpu.warpgroup.mma ... // Stores the fragmented results to the given memory nvgpu.warpgroup.mma.store [%res1, %res2], %matrixD : !nvgpu.warpgroup.result<tensor = !llvm.struct<...>>, !nvgpu.warpgroup.result<tensor = !llvm.struct<...>> to memref<128x128xf32,3> ``` Depends on llvm#65440
This PR introduces a new Op called `warpgroup.mma.store` to the NVGPU dialect of MLIR. The purpose of this operation is to facilitate storing the fragmented result(s) of type `nvgpu.warpgroup.accumulator` produced by `warpgroup.mma` to the given memref.
An example of a fragmented matrix is given here:
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d
The `warpgroup.mma.store` op does the following:
- Takes one or more values of `nvgpu.warpgroup.accumulator` type (fragmented result matrices).
Here's an example usage: