[mlir][spirv][gpu] Add conversion for load/store/mad coop matrix ops #66311

Merged: 1 commit, Sep 15, 2023
10 changes: 10 additions & 0 deletions mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
@@ -30,11 +30,21 @@ class MMAMatrixType;
void populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns);

/// Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV,
/// using the KHR Cooperative Matrix extension.
void populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns);

/// Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV,
/// using the NV Cooperative Matrix extension.
void populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns);

/// Returns a KHR cooperative matrix type corresponding to the MMAMatrixType
/// `type`.
spirv::CooperativeMatrixType
convertMMAToSPIRVCoopMatrixType(gpu::MMAMatrixType type);

/// Returns an NV cooperative matrix type corresponding to the MMAMatrixType
/// `type`.
spirv::CooperativeMatrixNVType
6 changes: 5 additions & 1 deletion mlir/include/mlir/Conversion/Passes.td
@@ -567,7 +567,11 @@ def ConvertGPUToSPIRV : Pass<"convert-gpu-to-spirv", "ModuleOp"> {
let options = [
Option<"use64bitIndex", "use-64bit-index",
"bool", /*default=*/"false",
"Use 64-bit integers to convert index types">
"Use 64-bit integers to convert index types">,
Option<"useCoopMatrixNV", "use-coop-matrix-nv",
"bool", /*default=*/"true",
"Use the NV cooperative matrix extension insted of the KHR extension"
" to lower GPU WMMA ops">,
];
}
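
Note that the new option defaults to true, so existing pipelines keep lowering WMMA ops through the NV extension unless they opt in to the KHR path. The corresponding mlir-opt invocation, mirroring the RUN line in the new test file further below (input.mlir is a placeholder):

  mlir-opt --convert-gpu-to-spirv="use-coop-matrix-nv=false" input.mlir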

25 changes: 25 additions & 0 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
@@ -146,6 +146,15 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad"
let results = (outs
SPIRV_AnyCooperativeMatrix:$result
);

let builders = [
OpBuilder<(ins "Type":$result, "Value":$pointer,
"spirv::ConstantOp":$stride,
"spirv::CooperativeMatrixLayoutKHR":$layout), [{
build($_builder, $_state, result, pointer, layout, stride,
spirv::MemoryAccessAttr{});
}]>
];
}

// -----
@@ -226,6 +235,15 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor
);

let results = (outs);

let builders = [
OpBuilder<(ins "Value":$pointer, "Value":$object,
"spirv::ConstantOp":$stride,
"spirv::CooperativeMatrixLayoutKHR":$layout), [{
build($_builder, $_state, pointer, object, layout, stride,
spirv::MemoryAccessAttr{});
}]>
];
}

// -----
@@ -332,6 +350,13 @@ def SPIRV_KHRCooperativeMatrixMulAddOp : SPIRV_KhrVendorOp<"CooperativeMatrixMul
let results = (outs
SPIRV_AnyCooperativeMatrix:$result
);

let builders = [
OpBuilder<(ins "Value":$a, "Value":$b, "Value":$c), [{
build($_builder, $_state, a, b, c,
spirv::CooperativeMatrixOperandsKHRAttr{});
}]>
];
}
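
These convenience builders default the optional memory-access and matrix-operands attributes to none, which keeps the conversion patterns below short. A minimal sketch of how a pattern might use them, assuming loc, rewriter, coopType, ptr, and matA/matB/matC are already in scope (the stride operand must be a spirv::ConstantOp, as in the builder signatures above):

  // Stride as an i32 spirv.Constant, as the load/store builders expect.
  auto stride = rewriter.create<spirv::ConstantOp>(
      loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(32));
  // Load, store, and multiply-add with the defaulted attributes.
  Value loaded = rewriter.create<spirv::KHRCooperativeMatrixLoadOp>(
      loc, coopType, ptr, stride, spirv::CooperativeMatrixLayoutKHR::RowMajor);
  rewriter.create<spirv::KHRCooperativeMatrixStoreOp>(
      loc, ptr, loaded, stride, spirv::CooperativeMatrixLayoutKHR::RowMajor);
  Value mad = rewriter.create<spirv::KHRCooperativeMatrixMulAddOp>(
      loc, matA, matB, matC);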

//===----------------------------------------------------------------------===//
20 changes: 16 additions & 4 deletions mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
@@ -86,13 +86,25 @@ void GPUToSPIRVPass::runOnOperation() {
SPIRVConversionOptions options;
options.use64bitIndex = this->use64bitIndex;
SPIRVTypeConverter typeConverter(targetAttr, options);
typeConverter.addConversion([&](gpu::MMAMatrixType type) -> Type {
return convertMMAToSPIRVCoopMatrixNVType(type);

typeConverter.addConversion([useNV = this->useCoopMatrixNV.getValue()](
gpu::MMAMatrixType type) -> Type {
if (useNV)
return convertMMAToSPIRVCoopMatrixNVType(type);

return convertMMAToSPIRVCoopMatrixType(type);
});

RewritePatternSet patterns(context);
populateGPUToSPIRVPatterns(typeConverter, patterns);
populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(typeConverter,
patterns);
if (this->useCoopMatrixNV) {
populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(typeConverter,
patterns);
} else {
populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(typeConverter,
patterns);
}

// TODO: Change SPIR-V conversion to be progressive and remove the following
// patterns.
mlir::arith::populateArithToSPIRVPatterns(typeConverter, patterns);
145 changes: 140 additions & 5 deletions mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -18,22 +18,28 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/StringSwitch.h"

namespace mlir::nv {
namespace {
#include <cassert>

namespace mlir {
/// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
/// when the elementwise op directly supports the cooperative matrix type.
/// Returns false if it cannot.
///
/// See SPV_NV_cooperative_matrix for supported elementwise ops.
static bool createElementwiseOp(ConversionPatternRewriter &builder,
gpu::SubgroupMmaElementwiseOp op,
spirv::CooperativeMatrixNVType coopType,
gpu::SubgroupMmaElementwiseOp op, Type coopType,
ValueRange operands) {
assert((isa<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType>(
coopType)));

switch (op.getOpType()) {
case gpu::MMAElementwiseOp::ADDF:
builder.replaceOpWithNewOp<spirv::FAddOp>(op, coopType, operands);
@@ -71,6 +77,110 @@ static bool createElementwiseOp(ConversionPatternRewriter &builder,
return false;
}

//===----------------------------------------------------------------------===//
// SPV_KHR_cooperative_matrix
//===----------------------------------------------------------------------===//

namespace khr {
namespace {

/// Converts the GPU MMA loadOp to KHRCooperativeMatrixLoad op in the SPIRV
/// dialect.
struct WmmaLoadOpToSPIRVLowering final
: OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
Location loc = op->getLoc();

auto retType = cast<gpu::MMAMatrixType>(op.getRes().getType());
MemRefType memrefType = op.getSrcMemref().getType();
Value bufferPtr =
spirv::getElementPtr(typeConverter, memrefType, adaptor.getSrcMemref(),
adaptor.getIndices(), loc, rewriter);

auto coopType =
typeConverter.convertType<spirv::CooperativeMatrixType>(retType);
if (!coopType)
return rewriter.notifyMatchFailure(op, "type conversion failed");

int64_t stride = op.getLeadDimension().getSExtValue();
IntegerType i32Type = rewriter.getI32Type();
auto strideValue = rewriter.create<spirv::ConstantOp>(
loc, i32Type, IntegerAttr::get(i32Type, stride));

bool isColMajor = op.getTranspose().value_or(false);
auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
: spirv::CooperativeMatrixLayoutKHR::RowMajor;

rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixLoadOp>(
op, coopType, bufferPtr, strideValue, layout);
return success();
}
};

/// Converts the GPU MMA StoreOp to KHRCooperativeMatrixStore op in the SPIRV
/// dialect.
struct WmmaStoreOpToSPIRVLowering final
: OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
Location loc = op->getLoc();

auto memrefType = cast<MemRefType>(op.getDstMemref().getType());
Value bufferPtr =
spirv::getElementPtr(typeConverter, memrefType, adaptor.getDstMemref(),
adaptor.getIndices(), loc, rewriter);

int64_t stride = op.getLeadDimension().getSExtValue();
IntegerType i32Type = rewriter.getI32Type();
auto strideValue = rewriter.create<spirv::ConstantOp>(
loc, i32Type, IntegerAttr::get(i32Type, stride));

bool isColMajor = op.getTranspose().value_or(false);
auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
: spirv::CooperativeMatrixLayoutKHR::RowMajor;

rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixStoreOp>(
op, bufferPtr, adaptor.getSrc(), strideValue, layout);
return success();
}
};

/// Converts GPU MMA Compute to KHRCooperativeMatrixMulAdd op in the SPIRV
/// dialect.
struct WmmaMmaOpToSPIRVLowering final
: OpConversionPattern<gpu::SubgroupMmaComputeOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixMulAddOp>(
subgroupMmaComputeOp, adaptor.getOpA(), adaptor.getOpB(),
adaptor.getOpC());
return success();
}
};

} // namespace
} // namespace khr

//===----------------------------------------------------------------------===//
// SPV_NV_cooperative_matrix
//===----------------------------------------------------------------------===//

namespace nv {
namespace {

/// Converts the GPU MMA loadOp to NVCooperativeMatrixLoad op in the SPIRV
/// dialect.
struct WmmaLoadOpToSPIRVLowering final
@@ -247,7 +357,8 @@ struct WmmaElementwiseOpToSPIRVScalarMulLowering final
};

} // namespace
} // namespace mlir::nv
} // namespace nv
} // namespace mlir

mlir::spirv::CooperativeMatrixNVType
mlir::convertMMAToSPIRVCoopMatrixNVType(gpu::MMAMatrixType type) {
@@ -257,6 +368,30 @@ mlir::convertMMAToSPIRVCoopMatrixNVType(gpu::MMAMatrixType type) {
elementType, spirv::Scope::Subgroup, retTypeShape[0], retTypeShape[1]);
}

mlir::spirv::CooperativeMatrixType
mlir::convertMMAToSPIRVCoopMatrixType(gpu::MMAMatrixType type) {
ArrayRef<int64_t> retTypeShape = type.getShape();
Type elementType = type.getElementType();

auto use =
llvm::StringSwitch<spirv::CooperativeMatrixUseKHR>(type.getOperand())
.Case("AOp", spirv::CooperativeMatrixUseKHR::MatrixA)
.Case("BOp", spirv::CooperativeMatrixUseKHR::MatrixB)
.Default(spirv::CooperativeMatrixUseKHR::MatrixAcc);

return spirv::CooperativeMatrixType::get(elementType, retTypeShape[0],
retTypeShape[1],
spirv::Scope::Subgroup, use);
}
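
Per the StringSwitch above, "AOp" and "BOp" map to MatrixA and MatrixB, and any other operand string (in practice "COp") falls through to MatrixAcc. A hedged illustration of the resulting mapping, assuming ctx is an MLIRContext with the GPU and SPIR-V dialects loaded:

  Builder b(ctx);
  auto mma = gpu::MMAMatrixType::get({16, 16}, b.getF16Type(), "AOp");
  auto coop = mlir::convertMMAToSPIRVCoopMatrixType(mma);
  // coop prints as !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>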

void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
using namespace mlir;
MLIRContext *context = patterns.getContext();
patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
khr::WmmaStoreOpToSPIRVLowering>(converter, context);
}
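
Downstream pipelines that do not go through convert-gpu-to-spirv can wire the KHR patterns up the same way the pass does. A minimal sketch, mirroring GPUToSPIRVPass::runOnOperation above (targetAttr and context are assumed to be in scope):

  SPIRVTypeConverter typeConverter(targetAttr);
  typeConverter.addConversion([](gpu::MMAMatrixType type) -> Type {
    return convertMMAToSPIRVCoopMatrixType(type);
  });
  RewritePatternSet patterns(context);
  populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(typeConverter, patterns);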

void mlir::populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
using namespace mlir;
@@ -0,0 +1,80 @@
// RUN: mlir-opt --convert-gpu-to-spirv="use-coop-matrix-nv=false" --cse \
// RUN: --split-input-file --verify-diagnostics %s | FileCheck %s

module attributes {
gpu.container_module,
spirv.target_env = #spirv.target_env<#spirv.vce<v1.6,
[Shader, CooperativeMatrixKHR, Float16],
[SPV_KHR_storage_buffer_storage_class, SPV_KHR_cooperative_matrix]>,
#spirv.resource_limits<>>} {

gpu.module @kernels {
// CHECK-LABEL: spirv.func @gpu_wmma_load_op
// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer>
gpu.func @gpu_wmma_load_op(%arg0 : memref<32x32xf16, #spirv.storage_class<StorageBuffer>>) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
%i = arith.constant 16 : index
%j = arith.constant 16 : index
// CHECK: %[[STRIDE:.+]] = spirv.Constant 32 : i32
// CHECK: spirv.KHR.CooperativeMatrixLoad {{%.*}}, %[[STRIDE]], <RowMajor> :
// CHECK-SAME: !spirv.ptr<f32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
%0 = gpu.subgroup_mma_load_matrix %arg0[%i, %j] {leadDimension = 32 : index} :
memref<32x32xf16, #spirv.storage_class<StorageBuffer>> -> !gpu.mma_matrix<16x16xf16, "COp">

// CHECK: spirv.KHR.CooperativeMatrixLoad {{%.*}}, %[[STRIDE]], <ColumnMajor> :
// CHECK-SAME: !spirv.ptr<f32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
%1 = gpu.subgroup_mma_load_matrix %arg0[%i, %j] {leadDimension = 32 : index, transpose} :
memref<32x32xf16, #spirv.storage_class<StorageBuffer>> -> !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: spirv.Return
gpu.return
}

// CHECK-LABEL: spirv.func @gpu_wmma_store_op
// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer>
// CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
gpu.func @gpu_wmma_store_op(%arg0: memref<32x32xf16, #spirv.storage_class<StorageBuffer>>,
%arg1: !gpu.mma_matrix<16x16xf16, "COp">) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
%i = arith.constant 16 : index
%j = arith.constant 16 : index
// CHECK: %[[STRIDE:.+]] = spirv.Constant 32 : i32
// CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, %[[STRIDE]], <RowMajor> :
// CHECK-SAME: !spirv.ptr<f32, StorageBuffer>, !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
gpu.subgroup_mma_store_matrix %arg1, %arg0[%i,%j] {leadDimension = 32 : index} :
!gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, #spirv.storage_class<StorageBuffer>>

// CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, %[[STRIDE]], <ColumnMajor> :
// CHECK-SAME: !spirv.ptr<f32, StorageBuffer>, !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
gpu.subgroup_mma_store_matrix %arg1, %arg0[%i,%j] {leadDimension = 32 : index, transpose} :
!gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, #spirv.storage_class<StorageBuffer>>
// CHECK: spirv.Return
gpu.return
}

// CHECK-LABEL: spirv.func @gpu_wmma_mma_op
// CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>
// CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixB>
// CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
gpu.func @gpu_wmma_mma_op(%A: !gpu.mma_matrix<16x16xf16, "AOp">,
%B: !gpu.mma_matrix<16x16xf16, "BOp">,
%C: !gpu.mma_matrix<16x16xf16, "COp">,
%ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
// CHECK: %[[MAD:.*]] = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} :
// CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>,
// CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixB>
// CHECK-SAME: -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
%D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">,
!gpu.mma_matrix<16x16xf16, "BOp">
-> !gpu.mma_matrix<16x16xf16, "COp">

%i = arith.constant 0 : index
// CHECK: spirv.KHR.CooperativeMatrixStore {{%.+}}, %[[MAD]], %{{.+}}, <RowMajor>
gpu.subgroup_mma_store_matrix %D, %ptr[%i,%i] {leadDimension = 32 : index} :
!gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
// CHECK: spirv.Return
gpu.return
}

}
}
@@ -1,4 +1,5 @@
// RUN: mlir-opt --convert-gpu-to-spirv --split-input-file --verify-diagnostics %s | FileCheck %s
// RUN: mlir-opt --convert-gpu-to-spirv="use-coop-matrix-nv=true" \
// RUN: --split-input-file --verify-diagnostics %s | FileCheck %s

module attributes {
gpu.container_module,