[mlir][spirv][gpu] Convert remaining wmma ops to KHR coop matrix #66455

Merged (1 commit, Sep 19, 2023)
231 changes: 129 additions & 102 deletions mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -24,11 +24,17 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"

#include <cassert>

namespace mlir {
//===----------------------------------------------------------------------===//
// Patterns and helpers used by both the KHR and the NV lowering paths.
//===----------------------------------------------------------------------===//

/// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
/// when the elementwise op directly supports cooperative matrix types.
/// Returns false if it cannot.
@@ -77,6 +83,119 @@ static bool createElementwiseOp(ConversionPatternRewriter &builder,
return false;
}

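/// Returns true if all `operands` share the same cooperative matrix type
/// (either the KHR or the NV variant).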
bool allOperandsHaveSameCoopMatrixType(ValueRange operands) {
assert(!operands.empty());
if (!llvm::all_equal(
llvm::map_range(operands, [](Value v) { return v.getType(); })))
return false;

return isa<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType>(
operands.front().getType());
}

namespace {
/// Converts GPU MMA ConstantMatrixOp to constant SPIR-V KHR/NV cooperative
/// matrix ops.
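/// For example (matching the tests below), `gpu.subgroup_mma_constant_matrix
/// %cst` lowers to `spirv.CompositeConstruct %cst` yielding a coop matrix.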
struct WmmaConstantOpToSPIRVLowering final
: OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
assert(adaptor.getOperands().size() == 1);
Value cst = adaptor.getOperands().front();
auto coopType = getTypeConverter()->convertType(op.getType());
if (!coopType)
return rewriter.notifyMatchFailure(op, "type conversion failed");

rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, coopType, cst);
return success();
}
};

/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
/// the default case.
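/// For example (matching the tests below), `addf` on two coop matrices maps
/// to `spirv.FAdd` over the converted `!spirv.coopmatrix` values.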
struct WmmaElementwiseOpToSPIRVDefaultLowering final
: OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// All operands should be of cooperative matrix types.
if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
return rewriter.notifyMatchFailure(op,
"not all operands are coop matrices");
}

auto coopType = getTypeConverter()->convertType(op.getType());
if (!coopType)
return rewriter.notifyMatchFailure(op, "type conversion failed");

return success(
createElementwiseOp(rewriter, op, coopType, adaptor.getOperands()));
}
};

/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
/// matrix times scalar case.
struct WmmaElementwiseOpToSPIRVScalarMulLowering final
: OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (adaptor.getOperands().size() != 2)
return failure();

// All operands should be of cooperative matrix types.
if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
return rewriter.notifyMatchFailure(op,
"not all operands are coop matrices");
}

if (op.getOpType() != gpu::MMAElementwiseOp::MULF)
return failure();

// Use the original operands to check whether one of the operands is a splat
// scalar value.
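// Illustrative example (mirrors the tests below): in
//   %b = gpu.subgroup_mma_constant_matrix %scalar
//   %c = gpu.subgroup_mma_elementwise mulf %a, %b
// the splat operand `%b` lets us emit `spirv.MatrixTimesScalar` instead of an
// elementwise `spirv.FMul` on two coop matrices.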
Value lhs = op.getOperands().front();
Value rhs = op.getOperands().back();
Value splat = nullptr;
Value matrix = nullptr;
if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
splat = adaptor.getOperands().front();
matrix = adaptor.getOperands().back();
} else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
matrix = adaptor.getOperands().front();
splat = adaptor.getOperands().back();
}
if (!splat || !matrix)
return rewriter.notifyMatchFailure(op, "no splat operand");

// Constant MMA matrix ops are converted to `spirv.CompositeConstruct` ops.
Value scalar;
auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>();
if (!cc) {
return rewriter.notifyMatchFailure(op,
"splat is not a composite construct");
}

assert(cc.getConstituents().size() == 1);
scalar = cc.getConstituents().front();

auto coopType = getTypeConverter()->convertType(op.getType());
if (!coopType)
return rewriter.notifyMatchFailure(op, "type conversion failed");
rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
op, coopType, ValueRange{matrix, scalar});
return success();
}
};
} // namespace

//===----------------------------------------------------------------------===//
// SPV_KHR_cooperative_matrix
//===----------------------------------------------------------------------===//
@@ -262,100 +381,6 @@ struct WmmaMmaOpToSPIRVLowering final
}
};

/// Converts GPU MMA ConstantMatrixOp to constant SPIR-V NV cooperative matrix
/// ops.
struct WmmaConstantOpToSPIRVLowering final
: OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantMatrixOp,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value cst = adaptor.getOperands()[0];
auto coopType = convertMMAToSPIRVCoopMatrixNVType(
cast<gpu::MMAMatrixType>(subgroupMmaConstantMatrixOp.getType()));
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
subgroupMmaConstantMatrixOp, coopType, cst);
return success();
}
};

/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
/// the default case.
struct WmmaElementwiseOpToSPIRVDefaultLowering final
: OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(gpu::SubgroupMmaElementwiseOp elementwiseOp,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// All operands should be of cooperative matrix types.
for (Value operand : adaptor.getOperands()) {
if (!isa<spirv::CooperativeMatrixNVType>(operand.getType()))
return failure();
}
auto coopType = convertMMAToSPIRVCoopMatrixNVType(
cast<gpu::MMAMatrixType>(elementwiseOp.getType()));
return success(createElementwiseOp(rewriter, elementwiseOp, coopType,
adaptor.getOperands()));
}
};

/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
/// matrix times scalar case.
struct WmmaElementwiseOpToSPIRVScalarMulLowering final
: OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(gpu::SubgroupMmaElementwiseOp elementwiseOp,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (adaptor.getOperands().size() != 2)
return failure();
// All operands should be of cooperative matrix types.
for (Value operand : adaptor.getOperands()) {
if (!isa<spirv::CooperativeMatrixNVType>(operand.getType()))
return failure();
}

if (elementwiseOp.getOpType() != gpu::MMAElementwiseOp::MULF)
return failure();

// Use the original operands to check whether one of the operands is a splat
// scalar value.
Value lhs = elementwiseOp.getOperands().front();
Value rhs = elementwiseOp.getOperands().back();
Value splat = nullptr;
Value matrix = nullptr;
if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
splat = adaptor.getOperands().front();
matrix = adaptor.getOperands().back();
} else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
matrix = adaptor.getOperands().front();
splat = adaptor.getOperands().back();
}
if (!splat || !matrix)
return failure();

// Constant MMA matrix ops are converted to spirv.CompositeConstruct ops.
Value scalar = nullptr;
auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>();
if (!cc)
return failure();
assert(cc.getConstituents().size() == 1);
scalar = cc.getConstituents().front();

auto coopType = convertMMAToSPIRVCoopMatrixNVType(
cast<gpu::MMAMatrixType>(elementwiseOp.getType()));
rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
elementwiseOp, coopType, ValueRange{matrix, scalar});
return success();
}
};

} // namespace
} // namespace nv
} // namespace mlir
@@ -389,19 +414,21 @@ void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
using namespace mlir;
MLIRContext *context = patterns.getContext();
patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
khr::WmmaStoreOpToSPIRVLowering>(converter, context);
khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
// Give the following patterns higher benefit to prevail over the default one.
patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
/*benefit=*/2);
}

void mlir::populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
using namespace mlir;
MLIRContext *context = patterns.getContext();
patterns
.add<nv::WmmaLoadOpToSPIRVLowering, nv::WmmaMmaOpToSPIRVLowering,
nv::WmmaStoreOpToSPIRVLowering, nv::WmmaConstantOpToSPIRVLowering,
nv::WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
patterns.add<nv::WmmaLoadOpToSPIRVLowering, nv::WmmaMmaOpToSPIRVLowering,
nv::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
// Give the following patterns higher benefit to prevail over the default one.
patterns.add<nv::WmmaElementwiseOpToSPIRVScalarMulLowering>(converter,
context,
/*benefit=*/2);
patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
/*benefit=*/2);
}
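
For context, here is a minimal sketch of how a conversion pass might drive these entry points. It is not part of this patch; the pass body and the surrounding helper name are assumptions based on the standard MLIR SPIR-V conversion utilities.

#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Hypothetical driver: lower the wmma ops nested under `gpuModule` to KHR
// cooperative matrix ops using the patterns populated above.
static LogicalResult lowerWmmaToKhrCoopMatrix(Operation *gpuModule) {
  MLIRContext *context = gpuModule->getContext();
  // The SPIR-V target environment drives both type conversion and legality.
  spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(gpuModule);
  std::unique_ptr<ConversionTarget> target =
      SPIRVConversionTarget::get(targetAttr);

  SPIRVTypeConverter typeConverter(targetAttr);
  // A real pass would also register the gpu.mma_matrix -> coopmatrix type
  // conversion on `typeConverter` before applying the patterns.
  RewritePatternSet patterns(context);
  populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(typeConverter,
                                                        patterns);
  return applyPartialConversion(gpuModule, *target, std::move(patterns));
}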
@@ -69,12 +69,106 @@ module attributes {
-> !gpu.mma_matrix<16x16xf16, "COp">

%i = arith.constant 0 : index
// CHECK: spirv.KHR.CooperativeMatrixStore {{%.+}}, %[[MAD]], %{{.+}}, <RowMajor>
// CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %[[MAD]], %{{.+}}, <RowMajor>
gpu.subgroup_mma_store_matrix %D, %ptr[%i,%i] {leadDimension = 32 : index} :
!gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
// CHECK: spirv.Return
gpu.return
}

// CHECK-LABEL: spirv.func @gpu_wmma_constant_op
gpu.func @gpu_wmma_constant_op(%ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
// CHECK: %[[CST1F:.+]] = spirv.Constant 1.000000e+00 : f16
%cst = arith.constant 1.0 : f16
// CHECK: %[[MAT:.+]] = spirv.CompositeConstruct %[[CST1F]] :
// CHECK-SAME: (f16) -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
%C = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf16, "COp">

%i = arith.constant 0 : index
// CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %[[MAT]], %{{.+}}, <RowMajor>
gpu.subgroup_mma_store_matrix %C, %ptr[%i,%i] {leadDimension = 32 : index} :
!gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
// CHECK: spirv.Return
gpu.return
}

// CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_default
// CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
// CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
gpu.func @gpu_wmma_elementwise_op_default(%A: !gpu.mma_matrix<16x16xf16, "COp">,
%B: !gpu.mma_matrix<16x16xf16, "COp">,
%ptr: memref<16x16xf32, #spirv.storage_class<StorageBuffer>>) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
// CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
%C = gpu.subgroup_mma_elementwise addf %A, %B :
(!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: {{%.*}} = spirv.FNegate {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
%D = gpu.subgroup_mma_elementwise negatef %C :
(!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
%E = gpu.subgroup_mma_elementwise divf %D, %A :
(!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: {{%.*}} = spirv.FConvert {{%.*}} :
// CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc> to !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
%F = gpu.subgroup_mma_elementwise extf %E :
(!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">

%i = arith.constant 0 : index
// CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %{{.+}}, %{{.+}}, <RowMajor>
gpu.subgroup_mma_store_matrix %F, %ptr[%i,%i] {leadDimension = 32 : index} :
!gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32, #spirv.storage_class<StorageBuffer>>
// CHECK: spirv.Return
gpu.return
}

// CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_matrix_times_scalar
// CHECK-SAME: %[[A:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
// CHECK-SAME: %[[S:.+]]: f16
gpu.func @gpu_wmma_elementwise_op_matrix_times_scalar(
%A: !gpu.mma_matrix<16x16xf16, "COp">, %scalar: f16,
%ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
%i = arith.constant 0 : index

%B = gpu.subgroup_mma_constant_matrix %scalar : !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: %[[C:.+]] = spirv.MatrixTimesScalar %[[A]], %[[S]] : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>, f16
// CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %[[C]], %{{.+}}, <RowMajor>
%C = gpu.subgroup_mma_elementwise mulf %A, %B :
(!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
gpu.subgroup_mma_store_matrix %C, %ptr[%i,%i] {leadDimension = 32 : index} :
!gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>

// CHECK: %[[D:.+]] = spirv.MatrixTimesScalar %[[C]], %[[S]] : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>, f16
// CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %[[D]], %{{.+}}, <RowMajor>
%D = gpu.subgroup_mma_elementwise mulf %B, %C :
(!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
gpu.subgroup_mma_store_matrix %D, %ptr[%i,%i] {leadDimension = 32 : index} :
!gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
// CHECK: spirv.Return
gpu.return
}

// CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_matrix_plus_scalar
// CHECK-SAME: %[[A:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
// CHECK-SAME: %[[S:.+]]: f16
gpu.func @gpu_wmma_elementwise_op_matrix_plus_scalar(
%A : !gpu.mma_matrix<16x16xf16, "COp">, %scalar : f16,
%ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
%i = arith.constant 0 : index

// CHECK: %[[SM:.+]] = spirv.CompositeConstruct %[[S]] : (f16) -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
%B = gpu.subgroup_mma_constant_matrix %scalar : !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: %[[C:.+]] = spirv.FAdd %[[A]], %[[SM]] : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
%C = gpu.subgroup_mma_elementwise addf %A, %B :
(!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">

// CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %[[C]], %{{.+}}, <RowMajor>
gpu.subgroup_mma_store_matrix %C, %ptr[%i,%i] {leadDimension = 32 : index} :
!gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
// CHECK: spirv.Return
gpu.return
}
}
}