[LLVMGPU] Wmma layout for LLVMGPU vector distribute pipeline (#16928)
This PR introduces the WMMA layout to the LLVMGPU vector distribute pipeline. The main
changes are fairly small: introducing the WMMA layouts, adding data-duplicate (which
expresses duplication of data by modifying the thread basis), and emitting the WMMA
intrinsic. Large portions of the changes simply generalize class and variable names to
reflect that we are doing MMA in general rather than MFMA-specific work, since most of
the existing MFMA layout work is reusable.
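
As background for the data-duplicate piece: RDNA3 WMMA in wave32 mode expects the A and B operand tiles to be replicated across the two 16-lane halves of the wave, which the new layouts express by keeping a 32-wide thread basis while only 16 threads own distinct rows/columns. A minimal standalone sketch of that lane-to-row mapping (illustration only, not code from this patch):

#include <cstdint>
#include <cstdio>

int main() {
  constexpr uint32_t kWaveSize = 32;  // RDNA3 wave32 subgroup
  constexpr uint32_t kTileDim = 16;   // 16x16x16 WMMA tile
  for (uint32_t lane = 0; lane < kWaveSize; ++lane) {
    // With 16 distinct row owners spread over a 32-thread basis, only the low
    // 4 bits of the lane id pick a row, so lane l and lane l + 16 read the
    // same 16 f16 elements of the A tile (the duplication described above).
    uint32_t row = lane % kTileDim;
    std::printf("lane %2u -> A row %2u\n", lane, row);
  }
  return 0;
}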

A TODO that I plan to handle after this patch is the layout for 16x16x16 with an FP16
accumulator. It has an awkward requirement to further interleave the output/C-matrix
data: the instruction still only produces 8 C-matrix elements per WMMA, but we want to
represent them as 16 elements where indices 0, 2, ..., 14 hold the real values.
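
To make that interleaving concrete, here is a minimal sketch (illustration only, under the representation described above) of pulling the 8 real per-lane results out of a 16-slot C-matrix slice where only the even indices are populated:

#include <array>
#include <cstdio>

int main() {
  // Hypothetical per-lane C-matrix slice for the f16-accumulator case:
  // 16 slots, but the wmma instruction only produced 8 real values,
  // stored at indices 0, 2, ..., 14.
  std::array<float, 16> laneSlice{};
  for (int i = 0; i < 16; i += 2) {
    laneSlice[i] = 1.0f + static_cast<float>(i) / 2.0f;  // placeholder data
  }

  // Gather the real results back into a dense 8-element vector.
  std::array<float, 8> realValues{};
  for (int i = 0; i < 8; ++i) {
    realValues[i] = laneSlice[2 * i];
  }

  for (float v : realValues) {
    std::printf("%g ", v);
  }
  std::printf("\n");
  return 0;
}
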
raikonenfnu authored Apr 2, 2024
1 parent d1eef77 commit 2c88e49
Showing 20 changed files with 804 additions and 203 deletions.
17 changes: 10 additions & 7 deletions compiler/plugins/target/ROCM/ROCMTargetFeatures.cpp
@@ -12,23 +12,26 @@
 namespace mlir::iree_compiler::IREE::HAL {
 
 static ArrayAttr getMfmaArrayAttr(MLIRContext *context,
-                                  ArrayRef<IREE::GPU::MFMAIntrinsic> types) {
-  SmallVector<Attribute> attrs(types.size(), IREE::GPU::MFMAAttr());
+                                  ArrayRef<IREE::GPU::MMAIntrinsic> types) {
+  SmallVector<Attribute> attrs(types.size(), IREE::GPU::MMAAttr());
   for (auto [idx, type] : llvm::enumerate(types)) {
-    attrs[idx] = IREE::GPU::MFMAAttr::get(context, type);
+    attrs[idx] = IREE::GPU::MMAAttr::get(context, type);
   }
   return ArrayAttr::get(context, attrs);
 }
 
 ArrayAttr getROCMSupportedMmaAttrs(MLIRContext *context, StringRef targetArch) {
   if (targetArch == "gfx940" || targetArch == "gfx942") { // MI300A/X
     return getMfmaArrayAttr(context,
-                            {IREE::GPU::MFMAIntrinsic::F16_16x16x16_F32,
-                             IREE::GPU::MFMAIntrinsic::F16_32x32x8_F32});
+                            {IREE::GPU::MMAIntrinsic::MFMA_F16_16x16x16_F32,
+                             IREE::GPU::MMAIntrinsic::MFMA_F16_32x32x8_F32});
   } else if (targetArch == "gfx90a") { // MI210
     return getMfmaArrayAttr(context,
-                            {IREE::GPU::MFMAIntrinsic::F16_16x16x16_F32,
-                             IREE::GPU::MFMAIntrinsic::F16_32x32x8_F32});
+                            {IREE::GPU::MMAIntrinsic::MFMA_F16_16x16x16_F32,
+                             IREE::GPU::MMAIntrinsic::MFMA_F16_32x32x8_F32});
+  } else if (targetArch == "gfx1100") { // RDNA3
+    return getMfmaArrayAttr(context,
+                            {IREE::GPU::MMAIntrinsic::WMMA_F16_16x16x16_F32});
   }
   return ArrayAttr();
 }
@@ -1,7 +1,9 @@
// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetBackends=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=gfx940 %s | FileCheck %s --check-prefix=MI300
// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetBackends=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=gfx942 %s | FileCheck %s --check-prefix=MI300
+// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetBackends=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=gfx1100 %s | FileCheck %s --check-prefix=RDNA3

-// MI300: mma_intrinsics = [#iree_gpu.mfma_layout<F16_16x16x16_F32>, #iree_gpu.mfma_layout<F16_32x32x8_F32>]
+// MI300: mma_intrinsics = [#iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>]
+// RDNA3: mma_intrinsics = [#iree_gpu.mma_layout<WMMA_F16_16x16x16_F32>]

stream.executable public @reduce_dispatch {
stream.executable.export @reduce_dispatch workgroups(%arg0: index) -> (index, index, index) {
@@ -57,18 +57,23 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
 
     // We assume there is an decision made before regarding which mfma intrinsic
     // to use and it is attached as an attribute to this contract op.
-    auto mfmaAttr =
-        contractOp->getAttrOfType<IREE::GPU::MFMAAttr>("iree.amdgpu.mfma");
-    if (!mfmaAttr) {
+    auto mmaAttr =
+        contractOp->getAttrOfType<IREE::GPU::MMAAttr>("iree.amdgpu.mma");
+    if (!mmaAttr) {
       return rewriter.notifyMatchFailure(
-          contractOp, "missing iree.amdgpu.mfma intrinsic attribute");
+          contractOp, "missing iree.amdgpu.mma intrinsic attribute");
     }
     // Get the storage vector types that each thread is in charge of.
-    auto [aVectorType, bVectorType, cVectorType] = mfmaAttr.getABCVectorTypes();
+    auto [aVectorType, bVectorType, cVectorType] = mmaAttr.getABCVectorTypes();
     // Get parameters for the amdgpu.mfma operation.
-    MFMAParameters mfmaParams;
-    std::tie(mfmaParams.m, mfmaParams.n, mfmaParams.k) = mfmaAttr.getMNKShape();
-    mfmaParams.blocks = mfmaAttr.getBlockSize();
+    AMDMMAParameters mmaParams;
+    std::tie(mmaParams.m, mmaParams.n, mmaParams.k) = mmaAttr.getMNKShape();
+    mmaParams.blocks = mmaAttr.getBlockSize();
+    IREE::GPU::MMAComputeType computeType = mmaAttr.getComputeType();
+    if (computeType == IREE::GPU::MMAComputeType::INVALID) {
+      return rewriter.notifyMatchFailure(
+          contractOp, "Cannot determine intrinsic compute type.");
+    }
 
     // Infer the contract kind so that we know know to correlate M/N/K dims.
     VectorContractOpInfo opDetail(contractOp);
@@ -162,8 +167,9 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
             rewriter.create<vector::ExtractOp>(loc, lhs, lhsBatchOffsets);
         Value rhsSlice =
             rewriter.create<vector::ExtractOp>(loc, rhs, rhsBatchOffsets);
-        accSlice = computeMMA(rewriter, loc, mfmaParams, lhsSlice, rhsSlice,
-                              accSlice, aVectorType, bVectorType, cVectorType);
+        accSlice =
+            computeMMA(rewriter, loc, mmaParams, lhsSlice, rhsSlice, accSlice,
+                       aVectorType, bVectorType, cVectorType, computeType);
       }
       finalTile = rewriter.create<vector::InsertOp>(loc, accSlice, finalTile,
                                                     resultBatchOffsets);
@@ -211,7 +217,7 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
     applyPermutationToVector(rhsOffsets, rhsLayout.getBatchOrder());
   }
 
-  struct MFMAParameters {
+  struct AMDMMAParameters {
     uint32_t m = 0;
     uint32_t n = 0;
     uint32_t k = 0;
@@ -221,16 +227,21 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
   // Generates amdgpu.mfma operation on the given inputs for the given MFMA
   // |intrinsic|.
   Value computeMMA(OpBuilder &builder, Location loc,
-                   const MFMAParameters &mfmaParams, Value a, Value b, Value c,
-                   VectorType aType, VectorType bType, VectorType cType) const {
+                   const AMDMMAParameters &mmaParams, Value a, Value b, Value c,
+                   VectorType aType, VectorType bType, VectorType cType,
+                   IREE::GPU::MMAComputeType computeType) const {
     Value aCast = builder.create<vector::ShapeCastOp>(a.getLoc(), aType, a);
     Value bCast = builder.create<vector::ShapeCastOp>(b.getLoc(), bType, b);
     Value cCast = builder.create<vector::ShapeCastOp>(c.getLoc(), cType, c);
 
-    Value mfmaOp = builder.create<amdgpu::MFMAOp>(
-        loc, cType, mfmaParams.m, mfmaParams.n, mfmaParams.k, mfmaParams.blocks,
-        aCast, bCast, cCast);
-    return builder.create<vector::ShapeCastOp>(c.getLoc(), c.getType(), mfmaOp);
+    Value mmaOp;
+    if (computeType == IREE::GPU::MMAComputeType::MFMA) {
+      mmaOp = builder.create<amdgpu::MFMAOp>(
+          loc, cType, mmaParams.m, mmaParams.n, mmaParams.k, mmaParams.blocks,
+          aCast, bCast, cCast);
+    } else if (computeType == IREE::GPU::MMAComputeType::WMMA) {
+      mmaOp = builder.create<amdgpu::WMMAOp>(loc, cType, aCast, bCast, cCast);
+    }
+    return builder.create<vector::ShapeCastOp>(c.getLoc(), c.getType(), mmaOp);
   }
 };

@@ -53,7 +53,7 @@ func.func @contract_to_mfma_32x32x8_mm(%a : vector<32x8xf16>, %b : vector<8x32xf
indexing_maps = [#map1, #map2, #map3],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>,
-iree.amdgpu.mfma = #iree_gpu.mfma_layout<F16_32x32x8_F32>,
+iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>,
"__vector_layout_test_anchor_operand_0" = #layout_a,
"__vector_layout_test_anchor_operand_1" = #layout_b,
"__vector_layout_test_anchor_operand_2" = #layout_c,
@@ -133,7 +133,7 @@ func.func @contract_to_mfma_16x16x16_mm(%a : vector<16x16xf16>, %b : vector<16x1
indexing_maps = [#map1, #map2, #map3],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>,
-iree.amdgpu.mfma = #iree_gpu.mfma_layout<F16_16x16x16_F32>,
+iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>,
"__vector_layout_test_anchor_operand_0" = #layout_a,
"__vector_layout_test_anchor_operand_1" = #layout_b,
"__vector_layout_test_anchor_operand_2" = #layout_b,
@@ -225,7 +225,7 @@ func.func @contract_to_mfma_32x32x8_mm_mnbatch(%a : vector<64x8xf16>, %b : vecto
indexing_maps = [#map1, #map2, #map3],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>,
-iree.amdgpu.mfma = #iree_gpu.mfma_layout<F16_32x32x8_F32>,
+iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>,
"__vector_layout_test_anchor_operand_0" = #layout_a,
"__vector_layout_test_anchor_operand_1" = #layout_b,
"__vector_layout_test_anchor_operand_2" = #layout_c,
@@ -318,7 +318,7 @@ func.func @contract_to_mfma_32x32x8_mm_kbatch(%a : vector<32x16xf16>, %b : vecto
indexing_maps = [#map1, #map2, #map3],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>,
-iree.amdgpu.mfma = #iree_gpu.mfma_layout<F16_32x32x8_F32>,
+iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>,
"__vector_layout_test_anchor_operand_0" = #layout_a,
"__vector_layout_test_anchor_operand_1" = #layout_b,
"__vector_layout_test_anchor_operand_2" = #layout_c,
@@ -409,7 +409,7 @@ func.func @contract_to_mfma_32x32x8_mm_mnbatch_order(%a : vector<64x8xf16>, %b :
indexing_maps = [#map1, #map2, #map3],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>,
-iree.amdgpu.mfma = #iree_gpu.mfma_layout<F16_32x32x8_F32>,
+iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>,
"__vector_layout_test_anchor_operand_0" = #layout_a,
"__vector_layout_test_anchor_operand_1" = #layout_b,
"__vector_layout_test_anchor_operand_2" = #layout_c,
@@ -504,7 +504,7 @@ func.func @contract_to_mfma_32x32x8_mmt(%a : vector<32x8xf16>, %b : vector<64x8x
indexing_maps = [#map1, #map2, #map3],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>,
-iree.amdgpu.mfma = #iree_gpu.mfma_layout<F16_32x32x8_F32>,
+iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>,
"__vector_layout_test_anchor_operand_0" = #layout_a,
"__vector_layout_test_anchor_operand_1" = #layout_b,
"__vector_layout_test_anchor_operand_2" = #layout_c,
@@ -532,3 +532,94 @@ builtin.module attributes { transform.with_named_sequence } {
// CHECK: %[[INS1:.+]] = vector.insert %17, %[[INS0]] [0, 1] : vector<4x1x1x4xf32> into vector<1x2x4x1x1x4xf32>
// CHECK: iree_vector_ext.to_simd %[[INS1]] : vector<1x2x4x1x1x4xf32> -> vector<32x64xf32>

// -----

// RDNA3 V_WMMA_F32_16X16X16_F32

#map1 = affine_map<(m, n, k) -> (m, k)>
#map2 = affine_map<(m, n, k) -> (k, n)>
#map3 = affine_map<(m, n, k) -> (m, n)>

// A: shape = 16x16, layout = layoutA
#layout_a = #iree_vector_ext.nested_layout<
subgroups_per_workgroup = [1, 1],
batches_per_subgroup = [1, 1],
outers_per_batch = [1, 1],
threads_per_outer = [16, 1],
elements_per_thread = [1, 16],

thread_order = [1, 0],

subgroup_basis = [1, 1],
thread_basis = [1, 32]
>

// B: shape = 16x16, layout = layoutB
#layout_b = #iree_vector_ext.nested_layout<
subgroups_per_workgroup = [1, 1],
batches_per_subgroup = [1, 1],
outers_per_batch = [1, 1],
threads_per_outer = [1, 16],
elements_per_thread = [16, 1],

element_order = [1, 0],

subgroup_basis = [1, 1],
thread_basis = [1, 32]
>

// C: shape = 16x16, layout = layoutC
#layout_c = #iree_vector_ext.nested_layout<
subgroups_per_workgroup = [1, 1],
batches_per_subgroup = [1, 1],
outers_per_batch = [8, 1],
threads_per_outer = [2, 16],
elements_per_thread = [1, 1],

element_order = [1, 0],

subgroup_basis = [1, 1],
thread_basis = [2, 16]
>

func.func @contract_to_wmma_16x16x16_mm(%a : vector<16x16xf16>, %b : vector<16x16xf16>, %c : vector<16x16xf32>) -> vector<16x16xf32> {
%output = vector.contract {
indexing_maps = [#map1, #map2, #map3],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>,
iree.amdgpu.mma = #iree_gpu.mma_layout<WMMA_F16_16x16x16_F32>,
"__vector_layout_test_anchor_operand_0" = #layout_a,
"__vector_layout_test_anchor_operand_1" = #layout_b,
"__vector_layout_test_anchor_operand_2" = #layout_c,
"__vector_layout_test_anchor_result_0" = #layout_c
} %a, %b, %c : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf32>
return %output : vector<16x16xf32>
}

builtin.module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
%top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
transform.yield
}
}


// CHECK-LABEL: func.func @contract_to_wmma_16x16x16_mm
// CHECK-SAME: (%[[A:.+]]: vector<16x16xf16>, %[[B:.+]]: vector<16x16xf16>, %[[C:.+]]: vector<16x16xf32>)
// CHECK: %[[INIT:.+]] = arith.constant dense<0.000000e+00> : vector<1x1x8x1x1x1xf32>
// CHECK: %[[C_SIMT:.+]] = iree_vector_ext.to_simt %[[C]] : vector<16x16xf32> -> vector<1x1x8x1x1x1xf32>
// CHECK: %[[A_SIMT:.+]] = iree_vector_ext.to_simt %[[A]] : vector<16x16xf16> -> vector<1x1x1x1x1x16xf16>
// CHECK: %[[B_SIMT:.+]] = iree_vector_ext.to_simt %[[B]] : vector<16x16xf16> -> vector<1x1x1x1x1x16xf16>
// CHECK: %[[C_VEC:.+]] = vector.extract %[[C_SIMT]][0, 0] : vector<8x1x1x1xf32> from vector<1x1x8x1x1x1xf32>
// CHECK: %[[A_VEC:.+]] = vector.extract %[[A_SIMT]][0, 0] : vector<1x1x1x16xf16> from vector<1x1x1x1x1x16xf16>
// CHECK: %[[B_VEC:.+]] = vector.extract %[[B_SIMT]][0, 0] : vector<1x1x1x16xf16> from vector<1x1x1x1x1x16xf16>
// CHECK: %[[A_CAST:.+]] = vector.shape_cast %[[A_VEC]] : vector<1x1x1x16xf16> to vector<16xf16>
// CHECK: %[[B_CAST:.+]] = vector.shape_cast %[[B_VEC]] : vector<1x1x1x16xf16> to vector<16xf16>
// CHECK: %[[C_CAST:.+]] = vector.shape_cast %[[C_VEC]] : vector<8x1x1x1xf32> to vector<8xf32>
// CHECK: %[[WMMA:.+]] = amdgpu.wmma %[[A_CAST]] * %[[B_CAST]] + %[[C_CAST]]
// CHECK: %[[R_CAST:.+]] = vector.shape_cast %[[WMMA]] : vector<8xf32> to vector<8x1x1x1xf32>
// CHECK: %[[INSERT:.+]] = vector.insert %[[R_CAST]], %[[INIT]] [0, 0] : vector<8x1x1x1xf32> into vector<1x1x8x1x1x1xf32>
// CHECK: %[[R_SIMD:.+]] = iree_vector_ext.to_simd %[[INSERT]] : vector<1x1x8x1x1x1xf32> -> vector<16x16xf32>
// CHECK: return {{.*}} %[[R_SIMD]]

@@ -34,7 +34,7 @@
// TRANSPOSE: return

hal.executable private @main_dispatch_0 {
-hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb", {mma_intrinsics = [#iree_gpu.mfma_layout<F16_16x16x16_F32>, #iree_gpu.mfma_layout<F16_32x32x8_F32>], target_arch = "gfx940", ukernels = "none"}>) {
+hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb", {mma_intrinsics = [#iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>], target_arch = "gfx940", ukernels = "none"}>) {
hal.executable.export public @main_dispatch_0_matmul_transpose_b_32000x32000x4096_f16 ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>) attributes {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>], subgroup_size = 64 : index, translation_info = #iree_codegen.translation_info<LLVMGPUMatmulSimt, {pipeline_depth = 0 : i64, store_stage = 1 : i64}>, workgroup_size = [64 : index, 16 : index, 1 : index]} {
^bb0(%arg0: !hal.device):
%c250 = arith.constant 250 : index