[LLVMGPU] Introduce a pass that pads matmuls to fit MMA shapes. (iree-org#17225)
There are two modes in the pass: one pads the parallel dimensions and the other pads the reduction dimensions. The padded sizes are inferred from producers that implement ValueBoundsOpInterface, i.e., operands are padded to the last tiling sizes.
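A minimal sketch (not part of this commit) of how the two modes could be sequenced in an LLVMGPU pass pipeline. Only createLLVMGPUPromoteMatmulToFitMMAPass and LLVMGPUMatmulPadOption come from this change; the helper name, the OpPassManager variable, and the placement relative to tiling passes are assumptions for illustration.

#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
#include "mlir/Pass/PassManager.h"

using namespace mlir;
using namespace mlir::iree_compiler;

// Hypothetical pipeline hook; exactly where the pass runs relative to tiling
// is an assumption, not something stated by this commit.
static void addPromoteMatmulToFitMMAPasses(OpPassManager &funcPassManager) {
  // Pad the parallel (M/N) dimensions; the resulting tensor.pad ops may later
  // be folded away (nofold = false in the pass).
  funcPassManager.addPass(createLLVMGPUPromoteMatmulToFitMMAPass(
      LLVMGPUMatmulPadOption::ParallelDims));
  // ... tiling of the parallel and reduction loops would run here ...
  // Pad the reduction (K) dimensions; these innermost pads are preserved
  // (nofold = true in the pass) so canonicalization folds only the outer pads.
  funcPassManager.addPass(createLLVMGPUPromoteMatmulToFitMMAPass(
      LLVMGPUMatmulPadOption::ReductionDims));
}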
Showing 9 changed files with 345 additions and 0 deletions.
compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUPromoteMatmulToFitMMA.cpp (177 additions, 0 deletions)
@@ -0,0 +1,177 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/LLVMGPU/PassDetail.h"
#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-llvmgpu-promote-matmul-to-fit-mma"

namespace mlir::iree_compiler {
#define GEN_PASS_DECL_LLVMGPUPROMOTEMATMULTOFITMMA
#include "iree/compiler/Codegen/LLVMGPU/Passes.h.inc"
namespace {

class LLVMGPUPromoteMatmulToFitMMAPass
    : public LLVMGPUPromoteMatmulToFitMMABase<
          LLVMGPUPromoteMatmulToFitMMAPass> {
public:
  explicit LLVMGPUPromoteMatmulToFitMMAPass(
      const LLVMGPUMatmulPadOption &option) {
    this->targetDimensions.setValue(option);
  }
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<tensor::TensorDialect, linalg::LinalgDialect>();
  }

  void padWithZeroValue(RewriterBase &rewriter, linalg::LinalgOp op,
                        utils::IteratorType targetIterType,
                        bool nofold) const {
    LLVM_DEBUG(llvm::dbgs() << "candidate: " << op << "\n");
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointAfter(op);

    SmallVector<int64_t> paddingDims;
    for (auto [index, iterType] : llvm::enumerate(op.getIteratorTypesArray())) {
      if (iterType == targetIterType) {
        paddingDims.push_back(index);
      }
    }

    SmallVector<bool> packPaddings(op.getNumDpsInputs(), nofold);

    // One is enough because they will essentially be padded to the
    // corresponding tile sizes, which should be multiples of MMA shapes.
    SmallVector<int64_t> padToMultipleOf(paddingDims.size(), 1);
    SmallVector<Attribute> paddingValueAttributes;
    for (auto &operand : op->getOpOperands()) {
      auto elemType = getElementTypeOrSelf(operand.get().getType());
      paddingValueAttributes.push_back(rewriter.getZeroAttr(elemType));
    }

    auto options =
        linalg::LinalgPaddingOptions()
            .setPaddingDimensions(paddingDims)
            .setPaddingValues(paddingValueAttributes)
            .setPadToMultipleOf(padToMultipleOf)
            .setPackPaddings(packPaddings)
            .setCopyBackOp(linalg::LinalgPaddingOptions::CopyBackOp::None);

    FailureOr<linalg::LinalgOp> result =
        linalg::padAndHoistLinalgOp(rewriter, op, options);
    if (failed(result)) {
      LLVM_DEBUG(llvm::dbgs() << "failed to pad op " << op << "\n");
    }
  }

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    auto funcOp = getOperation();

    // Preserve the innermost tensor.pad ops (i.e., pad for reduction dims), so
    // we can kick canonicalization patterns to fold outer tensor.pad ops away.
    bool nofold = false;
    utils::IteratorType targetIterType = utils::IteratorType::parallel;
    switch (targetDimensions) {
    case LLVMGPUMatmulPadOption::ParallelDims:
      LLVM_DEBUG(llvm::dbgs() << "padding parallel dims\n");
      targetIterType = utils::IteratorType::parallel;
      nofold = false;
      break;
    case LLVMGPUMatmulPadOption::ReductionDims:
      LLVM_DEBUG(llvm::dbgs() << "padding reduction dims\n");
      targetIterType = utils::IteratorType::reduction;
      nofold = true;
      break;
    default: // Unreachable.
      assert(false);
      break;
    };

    SmallVector<linalg::LinalgOp> candidates;
    funcOp->walk([&](linalg::LinalgOp op) {
      if (linalg::isaContractionOpInterface(op)) {
        candidates.push_back(op);
      }
    });

    IRRewriter rewriter(ctx);
    for (auto op : candidates) {
      padWithZeroValue(rewriter, op, targetIterType, nofold);
    }

    {
      RewritePatternSet patterns(ctx);
      linalg::populateSwapExtractSliceWithFillPatterns(patterns);
      linalg::FillOp::getCanonicalizationPatterns(patterns, ctx);
      memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
      ctx->getLoadedDialect<tensor::TensorDialect>()
          ->getCanonicalizationPatterns(patterns);
      tensor::PadOp::getCanonicalizationPatterns(patterns, ctx);
      if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
        LLVM_DEBUG(llvm::dbgs() << "----- cleanup failed -----\n");
        return signalPassFailure();
      }
    }

    // XXX(hanchung): This is needed for pad op fusion, which will remove
    // outer pad ops. I.e., it mainly wants to remove the first pad op in the
    // pad->extract_slice->pad chain, while the canonicalization pattern can
    // only recognize slice->pad->slice->pad.
    {
      SmallVector<tensor::PadOp> padOps;
      funcOp.walk([&](tensor::PadOp op) { padOps.push_back(op); });
      for (auto op : padOps) {
        auto srcExtractSliceOp =
            op.getSource().getDefiningOp<tensor::ExtractSliceOp>();
        if (!srcExtractSliceOp) {
          continue;
        }
        auto producerPadOp =
            srcExtractSliceOp.getSource().getDefiningOp<tensor::PadOp>();
        if (!producerPadOp) {
          continue;
        }
        auto src = producerPadOp.getSource()
                       .getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
        if (!src) {
          continue;
        }

        rewriter.setInsertionPointAfter(src);
        SmallVector<OpFoldResult> sizes =
            tensor::getMixedSizes(rewriter, op.getLoc(), src);
        SmallVector<OpFoldResult> offsets(sizes.size(),
                                          rewriter.getIndexAttr(0));
        SmallVector<OpFoldResult> strides(sizes.size(),
                                          rewriter.getIndexAttr(1));
        auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
            op.getLoc(), src.getResult(), offsets, sizes, strides);
        rewriter.startOpModification(op);
        op.getSourceMutable().assign(extractSliceOp.getResult());
        rewriter.finalizeOpModification(op);
      }

      RewritePatternSet patterns(ctx);
      tensor::PadOp::getCanonicalizationPatterns(patterns, ctx);
      if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
        return signalPassFailure();
      }
    }
  }
};
} // namespace

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createLLVMGPUPromoteMatmulToFitMMAPass(LLVMGPUMatmulPadOption option) {
  return std::make_unique<LLVMGPUPromoteMatmulToFitMMAPass>(option);
}

} // namespace mlir::iree_compiler