[LLVMGPU] Introduce a pass that pads matmul to fit mma shapes. (iree-org#17225)

There are two modes in the pass: one pads the parallel dimensions and the other pads the reduction dimensions. The padded sizes are inferred from the producers that implement ValueBoundsOpInterface, i.e., the operands are padded up to the last tiling sizes.
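
As a rough usage sketch (not part of this commit), the factory added in Passes.h could be wired into a function-level pass pipeline as below. Only createLLVMGPUPromoteMatmulToFitMMAPass and LLVMGPUMatmulPadOption come from this change; the helper name and the position of the tiling passes are hypothetical.

// Minimal sketch, assuming `funcPassManager` is nested on the function op.
#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
#include "mlir/Pass/PassManager.h"

namespace mlir::iree_compiler {
// Hypothetical helper: pad parallel dims before tiling, reduction dims after.
static void addMatmulPaddingPasses(OpPassManager &funcPassManager) {
  // Pad M/N up to the workgroup tile sizes (multiples of the MMA shape).
  funcPassManager.addPass(createLLVMGPUPromoteMatmulToFitMMAPass(
      LLVMGPUMatmulPadOption::ParallelDims));
  // ... reduction tiling would typically run in between ...
  // Pad K; the inner pad is created with nofold so it is not folded away.
  funcPassManager.addPass(createLLVMGPUPromoteMatmulToFitMMAPass(
      LLVMGPUMatmulPadOption::ReductionDims));
}
} // namespace mlir::iree_compiler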
hanhanW authored May 1, 2024
1 parent b8ef25c commit e6d8aa7
Showing 9 changed files with 345 additions and 0 deletions.
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
@@ -98,6 +98,7 @@ iree_compiler_cc_library(
"LLVMGPULowerExecutableTarget.cpp",
"LLVMGPUPackSharedMemoryAlloc.cpp",
"LLVMGPUPrefetching.cpp",
"LLVMGPUPromoteMatmulToFitMMA.cpp",
"LLVMGPUSelectLoweringStrategy.cpp",
"LLVMGPUTensorCoreVectorization.cpp",
"LLVMGPUTensorPad.cpp",
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
@@ -83,6 +83,7 @@ iree_cc_library(
"LLVMGPULowerExecutableTarget.cpp"
"LLVMGPUPackSharedMemoryAlloc.cpp"
"LLVMGPUPrefetching.cpp"
"LLVMGPUPromoteMatmulToFitMMA.cpp"
"LLVMGPUSelectLoweringStrategy.cpp"
"LLVMGPUTensorCoreVectorization.cpp"
"LLVMGPUTensorPad.cpp"
177 changes: 177 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUPromoteMatmulToFitMMA.cpp
@@ -0,0 +1,177 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/LLVMGPU/PassDetail.h"
#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-llvmgpu-promote-matmul-to-fit-mma"

namespace mlir::iree_compiler {
#define GEN_PASS_DECL_LLVMGPUPROMOTEMATMULTOFITMMA
#include "iree/compiler/Codegen/LLVMGPU/Passes.h.inc"
namespace {

class LLVMGPUPromoteMatmulToFitMMAPass
: public LLVMGPUPromoteMatmulToFitMMABase<
LLVMGPUPromoteMatmulToFitMMAPass> {
public:
explicit LLVMGPUPromoteMatmulToFitMMAPass(
const LLVMGPUMatmulPadOption &option) {
this->targetDimensions.setValue(option);
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<tensor::TensorDialect, linalg::LinalgDialect>();
}

void padWithZeroValue(RewriterBase &rewriter, linalg::LinalgOp op,
utils::IteratorType targetIterType, bool nofold) const {
LLVM_DEBUG(llvm::dbgs() << "candidate: " << op << "\n");
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointAfter(op);

SmallVector<int64_t> paddingDims;
for (auto [index, iterType] : llvm::enumerate(op.getIteratorTypesArray())) {
if (iterType == targetIterType) {
paddingDims.push_back(index);
}
}

SmallVector<bool> packPaddings(op.getNumDpsInputs(), nofold);

// One is enough because they will essentially be padded to corresponding
// tile sizes, which should be multiple of MMA shapes.
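    // (Illustrative note, added editorially: per the commit description, the
    // padded sizes are derived from ValueBoundsOpInterface on the producers,
    // e.g. a tensor.extract_slice of a 64x128 workgroup tile bounds the M/N
    // pad at 64x128, so rounding to a multiple of 1 adds nothing on top.)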
SmallVector<int64_t> padToMultipleOf(paddingDims.size(), 1);
SmallVector<Attribute> paddingValueAttributes;
for (auto &operand : op->getOpOperands()) {
auto elemType = getElementTypeOrSelf(operand.get().getType());
paddingValueAttributes.push_back(rewriter.getZeroAttr(elemType));
}

auto options =
linalg::LinalgPaddingOptions()
.setPaddingDimensions(paddingDims)
.setPaddingValues(paddingValueAttributes)
.setPadToMultipleOf(padToMultipleOf)
.setPackPaddings(packPaddings)
.setCopyBackOp(linalg::LinalgPaddingOptions::CopyBackOp::None);

FailureOr<linalg::LinalgOp> result =
linalg::padAndHoistLinalgOp(rewriter, op, options);
if (failed(result)) {
LLVM_DEBUG(llvm::dbgs() << "failed to pad op " << op << "\n");
}
}

void runOnOperation() override {
MLIRContext *ctx = &getContext();
auto funcOp = getOperation();

// Preserve the innermost tensor.pad ops (i.e., pad for reduction dims), so
// we can kick canonicalization patterns to fold outer tensor.pad ops away.
bool nofold = false;
utils::IteratorType targetIterType = utils::IteratorType::parallel;
switch (targetDimensions) {
case LLVMGPUMatmulPadOption::ParallelDims:
LLVM_DEBUG(llvm::dbgs() << "padding parallel dims\n");
targetIterType = utils::IteratorType::parallel;
nofold = false;
break;
case LLVMGPUMatmulPadOption::ReductionDims:
LLVM_DEBUG(llvm::dbgs() << "padding reduction dims\n");
targetIterType = utils::IteratorType::reduction;
nofold = true;
break;
default: // Unreachable.
assert(false);
break;
};

SmallVector<linalg::LinalgOp> candidates;
funcOp->walk([&](linalg::LinalgOp op) {
if (linalg::isaContractionOpInterface(op)) {
candidates.push_back(op);
}
});

IRRewriter rewriter(ctx);
for (auto op : candidates) {
padWithZeroValue(rewriter, op, targetIterType, nofold);
}

{
RewritePatternSet patterns(ctx);
linalg::populateSwapExtractSliceWithFillPatterns(patterns);
linalg::FillOp::getCanonicalizationPatterns(patterns, ctx);
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
ctx->getLoadedDialect<tensor::TensorDialect>()
->getCanonicalizationPatterns(patterns);
tensor::PadOp::getCanonicalizationPatterns(patterns, ctx);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
LLVM_DEBUG(llvm::dbgs() << "----- cleanup failed -----\n");
return signalPassFailure();
}
}

// XXX(hanchung): This is needed for pad op fusion, which will remove
// outer pad ops. I.e., it mainly wants to remove first pad op in the
// pad->extract_slice->pad chain, while the canonicalization pattern can
// only recognize slice->pad->slice->pad.
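    // For illustration (comment added editorially, not in the original
    // commit), the rewrite below turns
    //   %load  = flow.dispatch.tensor.load ...
    //   %pad0  = tensor.pad %load           // first pad (parallel dims)
    //   %slice = tensor.extract_slice %pad0
    //   %pad1  = tensor.pad %slice          // inner pad (reduction dims)
    // into
    //   %load  = flow.dispatch.tensor.load ...
    //   %full  = tensor.extract_slice %load // full sizes, zero offsets
    //   %pad1  = tensor.pad %full
    // so the first pad becomes dead and later canonicalization drops it.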
{
SmallVector<tensor::PadOp> padOps;
funcOp.walk([&](tensor::PadOp op) { padOps.push_back(op); });
for (auto op : padOps) {
auto srcExtractSliceOp =
op.getSource().getDefiningOp<tensor::ExtractSliceOp>();
if (!srcExtractSliceOp) {
continue;
}
auto producerPadOp =
srcExtractSliceOp.getSource().getDefiningOp<tensor::PadOp>();
if (!producerPadOp) {
continue;
}
auto src = producerPadOp.getSource()
.getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
if (!src) {
continue;
}

rewriter.setInsertionPointAfter(src);
SmallVector<OpFoldResult> sizes =
tensor::getMixedSizes(rewriter, op.getLoc(), src);
SmallVector<OpFoldResult> offsets(sizes.size(),
rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> strides(sizes.size(),
rewriter.getIndexAttr(1));
auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
op.getLoc(), src.getResult(), offsets, sizes, strides);
rewriter.startOpModification(op);
op.getSourceMutable().assign(extractSliceOp.getResult());
rewriter.finalizeOpModification(op);
}

RewritePatternSet patterns(ctx);
tensor::PadOp::getCanonicalizationPatterns(patterns, ctx);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
}
}
};
} // namespace

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createLLVMGPUPromoteMatmulToFitMMAPass(LLVMGPUMatmulPadOption option) {
return std::make_unique<LLVMGPUPromoteMatmulToFitMMAPass>(option);
}

} // namespace mlir::iree_compiler
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/PassDetail.h
@@ -7,6 +7,7 @@
#ifndef IREE_COMPILER_CODEGEN_LLVMGPU_PASS_DETAIL_H_
#define IREE_COMPILER_CODEGEN_LLVMGPU_PASS_DETAIL_H_

#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
6 changes: 6 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
@@ -107,6 +107,12 @@ createLLVMGPUPackSharedMemoryAlloc();
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createLLVMGPUPrefetchSharedMemoryPass();

/// Pass to pad operations on tensors in top-down order.
enum class LLVMGPUMatmulPadOption { ParallelDims, ReductionDims };
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createLLVMGPUPromoteMatmulToFitMMAPass(
LLVMGPUMatmulPadOption option = LLVMGPUMatmulPadOption::ParallelDims);

enum class GPUTensorCoreType {
WMMA = 0,
MMA_SYNC = 1,
19 changes: 19 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
@@ -78,6 +78,25 @@ def LLVMGPUPrefetchSharedMemory :
let constructor = "mlir::iree_compiler::createLLVMGPUPrefetchSharedMemoryPass()";
}

def LLVMGPUPromoteMatmulToFitMMA :
InterfacePass<"iree-llvmgpu-promote-matmul-to-fit-mma", "mlir::FunctionOpInterface"> {
let summary = "Pass to promote contraction ops to fit mma shapes";
let constructor = "mlir::iree_compiler::createLLVMGPUPromoteMatmulToFitMMAPass()";
let options = [
Option<"targetDimensions", "target-dimensions", "mlir::iree_compiler::LLVMGPUMatmulPadOption",
/*default=*/"mlir::iree_compiler::LLVMGPUMatmulPadOption::ParallelDims",
"Select the strategy to control how multi_reduction is lowered.",
[{::llvm::cl::values(
clEnumValN(mlir::iree_compiler::LLVMGPUMatmulPadOption::ParallelDims,
"parallel",
"Pad all the parallel dims for contraction ops."),
clEnumValN(mlir::iree_compiler::LLVMGPUMatmulPadOption::ReductionDims,
"reduction",
"Pad all the reduction dims for contraction ops.")
)}]>
];
}

def LLVMGPUSelectLoweringStrategy :
Pass<"iree-llvmgpu-select-lowering-strategy", "ModuleOp"> {
let summary = "Select a IREE::HAL::DispatchLoweringPassPipeline for lowering the target variant";
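
The target-dimensions option declared above can also be selected through MLIR's textual pipeline syntax. A minimal sketch, assuming the pass has been registered and `pm` is a module-level OpPassManager; only the pass argument "iree-llvmgpu-promote-matmul-to-fit-mma" and the option names come from this change, the helper is hypothetical.

#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"

// Nests the reduction-dim padding pass under func.func via a pipeline string;
// returns failure if the pipeline fails to parse.
static mlir::LogicalResult addReductionPadding(mlir::OpPassManager &pm) {
  return mlir::parsePassPipeline(
      "func.func(iree-llvmgpu-promote-matmul-to-fit-mma"
      "{target-dimensions=reduction})",
      pm);
}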
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
@@ -55,6 +55,7 @@ iree_lit_test_suite(
"pack_pipeline_test.mlir",
"pack_shared_memory_alloc.mlir",
"prefetch_shared_memory.mlir",
"promote_matmul_to_fit_mma.mlir",
"tensor_pad.mlir",
"tensorcore_vectorization.mlir",
"transform_dialect_bufferize.mlir",
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
@@ -42,6 +42,7 @@ iree_lit_test_suite(
"pack_pipeline_test.mlir"
"pack_shared_memory_alloc.mlir"
"prefetch_shared_memory.mlir"
"promote_matmul_to_fit_mma.mlir"
"reduction_pipeline_cuda.mlir"
"reduction_pipeline_rocm.mlir"
"reduction_pipeline_transform_cuda.mlir"
compiler/src/iree/compiler/Codegen/LLVMGPU/test/promote_matmul_to_fit_mma.mlir (new lit test; contents not shown)
