Skip to content

Commit

Permalink
[mlir][nvgpu] Fix 'warpgroup.mma.store' index calculation (#78413)
Browse files Browse the repository at this point in the history
This PR fixes the 'nvgpu.warpgroup.mma.store' index calculation. When
the destionation memref and current accumulator matrix were small, the
previous code was reaching out of range.
  • Loading branch information
grypp authored Jan 22, 2024
1 parent a31a600 commit 21830c9
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 10 deletions.
32 changes: 22 additions & 10 deletions mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1548,12 +1548,6 @@ struct NVGPUWarpgroupMmaStoreOpLowering
return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
};

Value tidx = b.create<NVVM::ThreadIdXOp>(i32);
Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize);
Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize);
Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);

auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
TypedValue<::mlir::MemRefType> memref) {
Type it = b.getIndexType();
Expand All @@ -1566,16 +1560,34 @@ struct NVGPUWarpgroupMmaStoreOpLowering
b.create<memref::StoreOp>(d1, memref, ValueRange{idx, idy1});
};

Value tidx = b.create<NVVM::ThreadIdXOp>(i32);
Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize);
Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize);
Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);

Value tj = makeMul(lane4modId, c2);
Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
if (offset)
ti = makeAdd(ti, makeConst(offset));
for (int i = 0; i < 2; ++i) {

auto structType = matrixD.getType().cast<LLVM::LLVMStructType>();

// Number of 32-bit registers owns per thread
constexpr unsigned numAdjacentRegisters = 2;
// Number of 8x8 matrices one below another per warp
constexpr unsigned numStackedMatrices = 2;

size_t storeCount = (structType.getBody().size() /
(numStackedMatrices * numAdjacentRegisters));

for (size_t i = 0; i < numStackedMatrices; ++i) {
Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
for (int j = 0; j < 16; ++j) {
for (size_t j = 0; j < storeCount; ++j) {
Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
int sIndex = i * 2 + j * 4;
makeExtractAndStore(sIndex, matrixD, idx, idy, dstMemref);
size_t structIndex = (i * numAdjacentRegisters) +
(j * (numStackedMatrices * numAdjacentRegisters));
makeExtractAndStore(structIndex, matrixD, idx, idy, dstMemref);
}
}
}
Expand Down
130 changes: 130 additions & 0 deletions mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1055,6 +1055,136 @@ func.func @warpgroup_mma_store(
return
}

// CHECK-LABEL: @warpgroup_mma_store_multiple
func.func @warpgroup_mma_store_multiple(
%shmem_m64n8k : memref<64x8xf32>,
%shmem_m64n16k : memref<64x16xf32>,
%shmem_m64n24k : memref<64x24xf32>,
%shmem_m64n32k : memref<64x32xf32>,
%shmem_m64n40k : memref<64x40xf32>,
%shmem_m64n48k : memref<64x48xf32>,
%shmem_m64n56k : memref<64x56xf32>,
%shmem_m64n64k : memref<64x64xf32>,
%shmem_m64n72k : memref<64x72xf32>,
%shmem_m64n80k : memref<64x80xf32>,
%shmem_m64n88k : memref<64x88xf32>,
%shmem_m64n96k : memref<64x96xf32>,
%shmem_m64n104k : memref<64x104xf32>,
%shmem_m64n112k : memref<64x112xf32>,
%shmem_m64n120k : memref<64x120xf32>,
%shmem_m64n128k : memref<64x128xf32>,
%shmem_m64n136k : memref<64x136xf32>,
%shmem_m64n144k : memref<64x144xf32>,
%shmem_m64n152k : memref<64x152xf32>,
%shmem_m64n160k : memref<64x160xf32>,
%shmem_m64n168k : memref<64x168xf32>,
%shmem_m64n176k : memref<64x176xf32>,
%shmem_m64n184k : memref<64x184xf32>,
%shmem_m64n192k : memref<64x192xf32>,
%shmem_m64n200k : memref<64x200xf32>,
%shmem_m64n208k : memref<64x208xf32>,
%shmem_m64n216k : memref<64x216xf32>,
%shmem_m64n224k : memref<64x224xf32>,
%shmem_m64n232k : memref<64x232xf32>,
%shmem_m64n240k : memref<64x240xf32>,
%shmem_m64n248k : memref<64x248xf32>,
%shmem_m64n256k : memref<64x256xf32>,
%res_m64n16k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>>,
%res_m64n24k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x24xf32>>,
%res_m64n32k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x32xf32>>,
%res_m64n40k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x40xf32>>,
%res_m64n48k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x48xf32>>,
%res_m64n56k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x56xf32>>,
%res_m64n64k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x64xf32>>,
%res_m64n72k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x72xf32>>,
%res_m64n80k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x80xf32>>,
%res_m64n88k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x88xf32>>,
%res_m64n96k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x96xf32>>,
%res_m64n104k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x104xf32>>,
%res_m64n112k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x112xf32>>,
%res_m64n120k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x120xf32>>,
%res_m64n128k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
%res_m64n136k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x136xf32>>,
%res_m64n144k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x144xf32>>,
%res_m64n152k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x152xf32>>,
%res_m64n160k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x160xf32>>,
%res_m64n168k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x168xf32>>,
%res_m64n176k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x176xf32>>,
%res_m64n184k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x184xf32>>,
%res_m64n192k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x192xf32>>,
%res_m64n200k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x200xf32>>,
%res_m64n208k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x208xf32>>,
%res_m64n216k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x216xf32>>,
%res_m64n224k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x224xf32>>,
%res_m64n232k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x232xf32>>,
%res_m64n240k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x240xf32>>,
%res_m64n248k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x248xf32>>,
%res_m64n256k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x256xf32>>) {
// CHECK-COUNT-8: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x16xf32>
// CHECK-COUNT-12: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x24xf32>
// CHECK-COUNT-16: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x32xf32>
// CHECK-COUNT-20: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x40xf32>
// CHECK-COUNT-24: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x48xf32>
// CHECK-COUNT-28: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x56xf32>
// CHECK-COUNT-32: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x64xf32>
// CHECK-COUNT-36: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x72xf32>
// CHECK-COUNT-40: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x80xf32>
// CHECK-COUNT-44: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x88xf32>
// CHECK-COUNT-48: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x96xf32>
// CHECK-COUNT-52: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x104xf32>
// CHECK-COUNT-56: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x112xf32>
// CHECK-COUNT-60: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x120xf32>
// CHECK-COUNT-64: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x128xf32>
// CHECK-COUNT-68: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x136xf32>
// CHECK-COUNT-72: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x144xf32>
// CHECK-COUNT-76: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x152xf32>
// CHECK-COUNT-80: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x160xf32>
// CHECK-COUNT-84: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x168xf32>
// CHECK-COUNT-88: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x176xf32>
// CHECK-COUNT-92: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x184xf32>
// CHECK-COUNT-96: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x192xf32>
// CHECK-COUNT-100: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x200xf32>
// CHECK-COUNT-104: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x208xf32>
// CHECK-COUNT-108: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x216xf32>
// CHECK-COUNT-112: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x224xf32>
// CHECK-COUNT-116: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x232xf32>
// CHECK-COUNT-120: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x240xf32>
// CHECK-COUNT-124: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x248xf32>
// CHECK-COUNT-128: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x256xf32>
nvgpu.warpgroup.mma.store %res_m64n16k, %shmem_m64n16k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>> to memref<64x16xf32>
nvgpu.warpgroup.mma.store %res_m64n24k, %shmem_m64n24k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x24xf32>> to memref<64x24xf32>
nvgpu.warpgroup.mma.store %res_m64n32k, %shmem_m64n32k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x32xf32>> to memref<64x32xf32>
nvgpu.warpgroup.mma.store %res_m64n40k, %shmem_m64n40k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x40xf32>> to memref<64x40xf32>
nvgpu.warpgroup.mma.store %res_m64n48k, %shmem_m64n48k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x48xf32>> to memref<64x48xf32>
nvgpu.warpgroup.mma.store %res_m64n56k, %shmem_m64n56k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x56xf32>> to memref<64x56xf32>
nvgpu.warpgroup.mma.store %res_m64n64k, %shmem_m64n64k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x64xf32>> to memref<64x64xf32>
nvgpu.warpgroup.mma.store %res_m64n72k, %shmem_m64n72k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x72xf32>> to memref<64x72xf32>
nvgpu.warpgroup.mma.store %res_m64n80k, %shmem_m64n80k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x80xf32>> to memref<64x80xf32>
nvgpu.warpgroup.mma.store %res_m64n88k, %shmem_m64n88k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x88xf32>> to memref<64x88xf32>
nvgpu.warpgroup.mma.store %res_m64n96k, %shmem_m64n96k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x96xf32>> to memref<64x96xf32>
nvgpu.warpgroup.mma.store %res_m64n104k, %shmem_m64n104k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x104xf32>> to memref<64x104xf32>
nvgpu.warpgroup.mma.store %res_m64n112k, %shmem_m64n112k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x112xf32>> to memref<64x112xf32>
nvgpu.warpgroup.mma.store %res_m64n120k, %shmem_m64n120k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x120xf32>> to memref<64x120xf32>
nvgpu.warpgroup.mma.store %res_m64n128k, %shmem_m64n128k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to memref<64x128xf32>
nvgpu.warpgroup.mma.store %res_m64n136k, %shmem_m64n136k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x136xf32>> to memref<64x136xf32>
nvgpu.warpgroup.mma.store %res_m64n144k, %shmem_m64n144k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x144xf32>> to memref<64x144xf32>
nvgpu.warpgroup.mma.store %res_m64n152k, %shmem_m64n152k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x152xf32>> to memref<64x152xf32>
nvgpu.warpgroup.mma.store %res_m64n160k, %shmem_m64n160k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x160xf32>> to memref<64x160xf32>
nvgpu.warpgroup.mma.store %res_m64n168k, %shmem_m64n168k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x168xf32>> to memref<64x168xf32>
nvgpu.warpgroup.mma.store %res_m64n176k, %shmem_m64n176k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x176xf32>> to memref<64x176xf32>
nvgpu.warpgroup.mma.store %res_m64n184k, %shmem_m64n184k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x184xf32>> to memref<64x184xf32>
nvgpu.warpgroup.mma.store %res_m64n192k, %shmem_m64n192k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x192xf32>> to memref<64x192xf32>
nvgpu.warpgroup.mma.store %res_m64n200k, %shmem_m64n200k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x200xf32>> to memref<64x200xf32>
nvgpu.warpgroup.mma.store %res_m64n208k, %shmem_m64n208k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x208xf32>> to memref<64x208xf32>
nvgpu.warpgroup.mma.store %res_m64n216k, %shmem_m64n216k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x216xf32>> to memref<64x216xf32>
nvgpu.warpgroup.mma.store %res_m64n224k, %shmem_m64n224k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x224xf32>> to memref<64x224xf32>
nvgpu.warpgroup.mma.store %res_m64n232k, %shmem_m64n232k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x232xf32>> to memref<64x232xf32>
nvgpu.warpgroup.mma.store %res_m64n240k, %shmem_m64n240k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x240xf32>> to memref<64x240xf32>
nvgpu.warpgroup.mma.store %res_m64n248k, %shmem_m64n248k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x248xf32>> to memref<64x248xf32>
nvgpu.warpgroup.mma.store %res_m64n256k, %shmem_m64n256k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x256xf32>> to memref<64x256xf32>
return
}

func.func @warpgroup_mma_init() {
//CHECK: %[[S1:.+]] = llvm.mlir.constant(0.000000e+00 : f32) : f3
//CHECK: %[[S0:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
Expand Down

0 comments on commit 21830c9

Please sign in to comment.