[AMD] Use Linear Layout conversions for AMDWmma (#5255)
Enable Linear Layout (LL) conversions for WMMA as well as for MFMA layouts.

See also: #5210

Signed-off-by: Ilya Veselov <iveselov.nn@gmail.com>
joviliast authored Nov 28, 2024
1 parent 55b741d commit 1cb0d99
Showing 2 changed files with 67 additions and 2 deletions.
4 changes: 2 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -374,9 +374,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     // TODO (Keren): Currently, we handle general mma/blocked/slice/dot(ampere)
     // -> mma/blocked/slice/dot(ampere) conversions. The following tasks must be
     // completed before we can remove the layoutIsOK check:
-    // 1. Support for AMD's WMMA
+    // 1. Support for AMD's WMMA dot operand
     std::function<bool(Attribute)> layoutIsOK = [&](Attribute layout) {
-      if (isa<NvidiaMmaEncodingAttr, AMDMfmaEncodingAttr>(layout)) {
+      if (isa<MmaEncodingTrait>(layout)) {
         return !useLegacyMMAConversion;
       }
       if (auto dotOperand = dyn_cast<DotOperandEncodingAttr>(layout)) {
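The entire behavioral change is the isa<> check above: rather than enumerating concrete encodings, it now accepts any attribute implementing the MmaEncodingTrait interface, which AMD's WMMA encoding shares with NVIDIA MMA and AMD MFMA. A minimal sketch of the resulting predicate (the helper name takesLinearLayoutPath is illustrative and not part of the patch; the include paths assume a Triton source tree):

#include "mlir/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

using namespace mlir;
using namespace mlir::triton::gpu;

// Illustrative helper (not in the patch): NvidiaMmaEncodingAttr,
// AMDMfmaEncodingAttr, and AMDWmmaEncodingAttr all implement
// MmaEncodingTrait, so a single isa<> test routes every MMA-family
// layout through the linear-layout conversion unless the legacy
// MMA conversion is explicitly requested.
static bool takesLinearLayoutPath(Attribute layout,
                                  bool useLegacyMMAConversion) {
  if (isa<MmaEncodingTrait>(layout))
    return !useLegacyMMAConversion;
  return false; // non-MMA layouts are handled by the rest of layoutIsOK
}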
65 changes: 65 additions & 0 deletions test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir
@@ -1,5 +1,6 @@
 // RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx1100 | FileCheck %s
 
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
 #shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}>
 #mma1 = #ttg.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}>
 #mma2 = #ttg.amd_wmma<{version = 2, warpsPerCTA = [2, 2]}>
@@ -97,6 +98,70 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
     // CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)>
     tt.return
   }
+
+  // CHECK-LABEL: blocked_to_wmma1
+  tt.func @blocked_to_wmma1(%arg0: tensor<128x16xi32, #blocked>) {
+    // CHECK-COUNT-16: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+    // CHECK-COUNT-32: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+    %0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<128x16xi32, #blocked> -> tensor<128x16xi32, #mma1>
+    tt.return
+  }
+
+  // CHECK-LABEL: slice_blocked_to_wmma1
+  tt.func @slice_blocked_to_wmma1(%arg0: tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>>) {
+    // CHECK-COUNT-16: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+    // CHECK-COUNT-1: llvm.insertvalue {{.*}} : !llvm.struct<(i32)>
+    %0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #mma1}>>
+    tt.return
+  }
+
+  // CHECK-LABEL: wmma1_to_blocked
+  tt.func @wmma1_to_blocked(%arg0: tensor<128x16xi32, #mma1>) {
+    // CHECK-COUNT-32: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+    // CHECK-COUNT-16: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+    %0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<128x16xi32, #mma1> -> tensor<128x16xi32, #blocked>
+    tt.return
+  }
+
+  // CHECK-LABEL: slice_wmma1_to_blocked
+  tt.func @slice_wmma1_to_blocked(%arg0: tensor<16xi32, #ttg.slice<{dim = 0, parent = #mma1}>>) {
+    // CHECK-COUNT-1: llvm.extractvalue {{.*}} : !llvm.struct<(i32)>
+    // CHECK-COUNT-16: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+    %0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #mma1}>> -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+    tt.return
+  }
+
+  // CHECK-LABEL: blocked_to_wmma2
+  tt.func @blocked_to_wmma2(%arg0: tensor<128x16xi32, #blocked>) {
+    // CHECK-COUNT-16: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+    // CHECK-COUNT-32: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+    %0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<128x16xi32, #blocked> -> tensor<128x16xi32, #mma2>
+    tt.return
+  }
+
+  // CHECK-LABEL: slice_blocked_to_wmma2
+  tt.func @slice_blocked_to_wmma2(%arg0: tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>>) {
+    // CHECK-COUNT-16: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+    // CHECK-COUNT-1: llvm.insertvalue {{.*}} : !llvm.struct<(i32)>
+    %0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #mma2}>>
+    tt.return
+  }
+
+  // CHECK-LABEL: wmma2_to_blocked
+  tt.func @wmma2_to_blocked(%arg0: tensor<128x16xi32, #mma2>) {
+    // CHECK-COUNT-32: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+    // CHECK-COUNT-16: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+    %0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<128x16xi32, #mma2> -> tensor<128x16xi32, #blocked>
+    tt.return
+  }
+
+  // CHECK-LABEL: slice_wmma2_to_blocked
+  tt.func @slice_wmma2_to_blocked(%arg0: tensor<16xi32, #ttg.slice<{dim = 0, parent = #mma2}>>) {
+    // CHECK-COUNT-1: llvm.extractvalue {{.*}} : !llvm.struct<(i32)>
+    // CHECK-COUNT-16: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+    %0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #mma2}>> -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+    tt.return
+  }
 }
 
 // -----
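As a sanity check on the CHECK-COUNT values in the new tests, the struct sizes follow from dividing tensor elements among threads and accounting for replication. A small sketch of the arithmetic (the WMMA replication reasoning is an assumption based on the 16x16 WMMA tile and the [2, 2] warp grid, not something the patch states):

#include <cassert>

int main() {
  // #blocked over 128x16: threadsPerWarp = [32, 1] and warpsPerCTA = [4, 1]
  // give 32 * 4 = 128 threads, so each thread holds (128 * 16) / 128 = 16
  // values -> the 16-element !llvm.struct in the checks.
  int blockedRegs = (128 * 16) / (32 * 4);
  assert(blockedRegs == 16);

  // #mma1/#mma2 (WMMA, warpsPerCTA = [2, 2]) over 128x16: a 16x16 WMMA tile
  // holds 256 / 32 = 8 values per lane; the 2x2 warp grid spans 32x32, so
  // dim 0 repeats 128 / 32 = 4 times and dim 1 (16 < 32) is broadcast,
  // giving 8 * 4 = 32 values -> the 32-element !llvm.struct in the checks.
  int wmmaRegs = (16 * 16 / 32) * (128 / 32);
  assert(wmmaRegs == 32);
  return 0;
}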
