Skip to content

Commit

Permalink
Clean up encoding-related code. NFC. (#19717)
Browse files Browse the repository at this point in the history
Fixing misc issues before modifying the surrounding code.

Signed-off-by: Jakub Kuderski <jakub@nod-labs.com>
  • Loading branch information
kuhar authored Jan 16, 2025
1 parent 3032df2 commit 5ee9b27
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 49 deletions.
24 changes: 9 additions & 15 deletions compiler/src/iree/compiler/Codegen/Common/MaterializeEncoding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,27 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/EncodingUtils.h"
#include "iree/compiler/Codegen/Common/PassUtils.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUDialect.h"
#include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUTypes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
#include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "iree/compiler/Dialect/Stream/Analysis/Affinity.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

#define DEBUG_TYPE "iree-codegen--materialize-encoding"
#define DEBUG_TYPE "iree-codegen-materialize-encoding"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

Expand Down Expand Up @@ -150,17 +147,16 @@ getFuncExecutableTargetAttrs(FunctionOpInterface funcOp,
return executableTargetAttrs;
}

struct MaterializeHostEncodingPass
: public impl::MaterializeHostEncodingPassBase<
MaterializeHostEncodingPass> {
struct MaterializeHostEncodingPass final
: impl::MaterializeHostEncodingPassBase<MaterializeHostEncodingPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<arith::ArithDialect, tensor::TensorDialect,
IREE::Codegen::IREECodegenDialect,
IREE::CPU::IREECPUDialect, IREE::GPU::IREEGPUDialect>();
}

void runOnOperation() override {
auto moduleOp = getOperation();
ModuleOp moduleOp = getOperation();

// Run required analysis passes.
IREE::Stream::AffinityAnalysis affinityAnalysis(moduleOp);
Expand Down Expand Up @@ -211,11 +207,9 @@ struct MaterializeHostEncodingPass
// that. It should _not_ be running on both - target-specific codegen passes
// are not allowed on host programs and it's a big violation of layering that
// this exists.
struct MaterializeDeviceEncodingPass
: public impl::MaterializeDeviceEncodingPassBase<
MaterializeDeviceEncodingPass> {
using impl::MaterializeDeviceEncodingPassBase<
MaterializeDeviceEncodingPass>::MaterializeDeviceEncodingPassBase;
struct MaterializeDeviceEncodingPass final
: impl::MaterializeDeviceEncodingPassBase<MaterializeDeviceEncodingPass> {
using Base::Base;

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<arith::ArithDialect, tensor::TensorDialect,
Expand All @@ -224,7 +218,7 @@ struct MaterializeDeviceEncodingPass
}

void runOnOperation() override {
auto funcOp = getOperation();
FunctionOpInterface funcOp = getOperation();
auto executableTargetAttr = IREE::HAL::ExecutableTargetAttr::lookup(funcOp);
if (failed(materializeFuncOpEncodings(funcOp, executableTargetAttr,
testCLGPUTarget))) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Pass/PassManager.h"
Expand All @@ -36,7 +34,7 @@ struct MaterializeEncodingIntoNopPass final

void runOnOperation() override {
MLIRContext *context = &getContext();
auto operation = getOperation();
FunctionOpInterface operation = getOperation();

auto materializeEncodingValueFn =
[](RankedTensorType, OpBuilder &,
Expand Down
32 changes: 15 additions & 17 deletions compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

Expand Down Expand Up @@ -87,14 +86,14 @@ static bool hasMatmulLikeBody(linalg::LinalgOp linalgOp) {
if (!yieldOp) {
return false;
}
auto addOp = yieldOp->getOperand(0).getDefiningOp();
Operation *addOp = yieldOp->getOperand(0).getDefiningOp();
if (!addOp || !isa<arith::AddIOp, arith::AddFOp>(addOp)) {
return false;
}
auto addLhs = addOp->getOperand(0);
auto addRhs = addOp->getOperand(1);
auto addLhsOp = addLhs.getDefiningOp();
auto addRhsOp = addRhs.getDefiningOp();
Value addLhs = addOp->getOperand(0);
Value addRhs = addOp->getOperand(1);
Operation *addLhsOp = addLhs.getDefiningOp();
Operation *addRhsOp = addRhs.getDefiningOp();
if (!(addLhsOp && addRhs == outBlockArg) &&
!(addRhsOp && addLhs == outBlockArg)) {
return false;
Expand All @@ -103,8 +102,8 @@ static bool hasMatmulLikeBody(linalg::LinalgOp linalgOp) {
if (!isa<arith::MulFOp, arith::MulIOp>(mulOp)) {
return false;
}
auto mulLhs = mulOp->getOperand(0);
auto mulRhs = mulOp->getOperand(1);
Value mulLhs = mulOp->getOperand(0);
Value mulRhs = mulOp->getOperand(1);
auto mulLhsOp = mulLhs.getDefiningOp<CastOpInterface>();
auto mulRhsOp = mulRhs.getDefiningOp<CastOpInterface>();
if (!isa<BlockArgument>(mulLhs) && !mulLhsOp && !isa<BlockArgument>(mulRhs) &&
Expand Down Expand Up @@ -155,11 +154,11 @@ static LogicalResult isSupportedContractionOp(PatternRewriter &rewriter,

namespace {

class setContractionOpEncoding
class SetContractionOpEncoding final
: public OpInterfaceRewritePattern<linalg::LinalgOp> {
public:
using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
explicit setContractionOpEncoding(MLIRContext *ctx, int64_t factor)
using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
explicit SetContractionOpEncoding(MLIRContext *ctx, int64_t factor)
: OpInterfaceRewritePattern<linalg::LinalgOp>(ctx), padFactor(factor) {}

LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
Expand Down Expand Up @@ -244,8 +243,8 @@ class setContractionOpEncoding
/// Pattern to fold a `linalg.fill` -> `iree_encoding.set_encoding`
/// operation into a `linalg.fill` of the encoded type.
struct FoldFillWithSetEncoding final
: public OpRewritePattern<IREE::Encoding::SetEncodingOp> {
using OpRewritePattern<IREE::Encoding::SetEncodingOp>::OpRewritePattern;
: OpRewritePattern<IREE::Encoding::SetEncodingOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(IREE::Encoding::SetEncodingOp encodingOp,
PatternRewriter &rewriter) const override {
Expand All @@ -267,15 +266,14 @@ struct FoldFillWithSetEncoding final
}
};

struct SetEncodingPass final
: public impl::SetEncodingPassBase<SetEncodingPass> {
struct SetEncodingPass final : impl::SetEncodingPassBase<SetEncodingPass> {
using Base::Base;
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
patterns.insert<setContractionOpEncoding>(context, padFactor);
patterns.add<SetContractionOpEncoding>(context, padFactor);
linalg::FillOp::getCanonicalizationPatterns(patterns, context);
patterns.insert<FoldFillWithSetEncoding>(context);
patterns.add<FoldFillWithSetEncoding>(context);
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
return signalPassFailure();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,9 @@
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/GlobalOptimization/Passes.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "iree/compiler/Utils/PassUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Passes.h"

namespace mlir::iree_compiler::GlobalOptimization {
Expand All @@ -27,7 +19,7 @@ namespace mlir::iree_compiler::GlobalOptimization {
// path. This is mainly for testing.
static llvm::cl::opt<bool> clEnableExperimentalRocmDataTiling(
"iree-global-opt-experimental-rocm-data-tiling",
llvm::cl::desc("Enables data-tiling materializatino for rocm backends "
llvm::cl::desc("Enables data-tiling materialization for rocm backends "
"(experimental)."),
llvm::cl::init(false));

Expand All @@ -38,10 +30,9 @@ using FunctionLikeNest =
MultiOpNest<IREE::Util::InitializerOp, IREE::Util::FuncOp>;

namespace {
class MaterializeHomogeneousEncodingsPass
: public impl::MaterializeHomogeneousEncodingsPassBase<
struct MaterializeHomogeneousEncodingsPass final
: impl::MaterializeHomogeneousEncodingsPassBase<
MaterializeHomogeneousEncodingsPass> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<IREE::HAL::HALDialect, tensor::TensorDialect,
IREE::Codegen::IREECodegenDialect>();
Expand Down Expand Up @@ -72,7 +63,7 @@ class MaterializeHomogeneousEncodingsPass
// TODO: vmvx has its own logic about supporting dynamic tile
// sizes. It is not fully integrated into the pipeline, so we defer the
// materialization to the end.
auto executableTarget = executableTargets[0];
IREE::HAL::ExecutableTargetAttr executableTarget = executableTargets[0];
if (executableTarget.getBackend() == "vmvx") {
return;
}
Expand Down

0 comments on commit 5ee9b27

Please sign in to comment.