Skip to content

Commit

Permalink
Move EncodingAttr and related ops from LinalgExt to a new `Encodi…
Browse files Browse the repository at this point in the history
…ng` dialect (iree-org#17277)

Fixes iree-org#17191.
  • Loading branch information
bjacob authored May 6, 2024
1 parent c2114b8 commit 792f14d
Show file tree
Hide file tree
Showing 76 changed files with 2,224 additions and 1,650 deletions.
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ iree_compiler_cc_library(
"//compiler/src/iree/compiler/Codegen/Interfaces:UKernelOpInterface",
"//compiler/src/iree/compiler/Codegen/Transforms",
"//compiler/src/iree/compiler/Codegen/Utils",
"//compiler/src/iree/compiler/Dialect/Encoding/IR",
"//compiler/src/iree/compiler/Dialect/Flow/IR",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
"//compiler/src/iree/compiler/Dialect/LinalgExt/IR",
Expand Down Expand Up @@ -244,6 +245,7 @@ iree_compiler_cc_library(
# Dialects
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
"//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
"//compiler/src/iree/compiler/Dialect/Encoding/IR",
"//compiler/src/iree/compiler/Dialect/Flow/IR",
"//compiler/src/iree/compiler/Dialect/LinalgExt/IR",
"//compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions:LinalgExtExtensions",
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ iree_cc_library(
iree::compiler::Codegen::Interfaces::UKernelOpInterface
iree::compiler::Codegen::Transforms
iree::compiler::Codegen::Utils
iree::compiler::Dialect::Encoding::IR
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::LinalgExt::IR
Expand Down Expand Up @@ -262,6 +263,7 @@ iree_cc_library(
iree::compiler::Codegen::LLVMCPU::TransformExtensions::LLVMCPUExtensions
iree::compiler::Codegen::LLVMGPU::TransformExtensions::LLVMGPUExtensions
iree::compiler::Codegen::TransformStrategies::Common::TransformStrategies
iree::compiler::Dialect::Encoding::IR
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::Flow::TransformExtensions::FlowExtensions
iree::compiler::Dialect::LinalgExt::IR
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ iree_compiler_cc_library(
"//compiler/src/iree/compiler/Codegen/Interfaces:UKernelOpInterface",
"//compiler/src/iree/compiler/Codegen/Transforms",
"//compiler/src/iree/compiler/Codegen/Utils",
"//compiler/src/iree/compiler/Dialect/Encoding/IR",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
"//compiler/src/iree/compiler/Dialect/LinalgExt/IR",
"//runtime/src/iree/builtins/ukernel:exported_bits",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineDialect",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ iree_cc_library(
iree::compiler::Codegen::Interfaces::UKernelOpInterface
iree::compiler::Codegen::Transforms
iree::compiler::Codegen::Utils
iree::compiler::Dialect::Encoding::IR
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::LinalgExt::IR
PUBLIC
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
Expand Down Expand Up @@ -474,11 +474,11 @@ matchDAGForUKernel(RewriterBase &rewriter, tensor::UnPackOp op,
}

static uint32_t
getFlagForUserAndOperandTypes(IREE::LinalgExt::EncodingAttr encoding,
getFlagForUserAndOperandTypes(IREE::Encoding::EncodingAttr encoding,
ArrayRef<Attribute> operandTypes) {
// There are currently no batch_mmt4d ukernels, so check for no batch
// dimension.
auto cDims = IREE::LinalgExt::getEncodingContractionDims(encoding);
auto cDims = IREE::Encoding::getEncodingContractionDims(encoding);
if (failed(cDims) || !cDims->batch.empty() || operandTypes.size() != 3) {
return IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_NONE;
}
Expand All @@ -505,13 +505,13 @@ getFlagForUserAndOperandTypes(IREE::LinalgExt::EncodingAttr encoding,
}
}

static uint32_t getFlagForRole(IREE::LinalgExt::EncodingRole role) {
static uint32_t getFlagForRole(IREE::Encoding::EncodingRole role) {
switch (role) {
case IREE::LinalgExt::EncodingRole::LHS:
case IREE::Encoding::EncodingRole::LHS:
return IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_LHS;
case IREE::LinalgExt::EncodingRole::RHS:
case IREE::Encoding::EncodingRole::RHS:
return IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_RHS;
case IREE::LinalgExt::EncodingRole::RESULT:
case IREE::Encoding::EncodingRole::RESULT:
return IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_RESULT;
default:
return IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_NONE;
Expand All @@ -534,8 +534,8 @@ matchDAGForUKernel(RewriterBase &rewriter, IREE::Codegen::QueryTileSizesOp op,
if (tensorType.getRank() != 2) {
return rewriter.notifyMatchFailure(op, "only the 2D case is implemented");
}
auto encoding = tensorType.getEncoding()
.dyn_cast_or_null<IREE::LinalgExt::EncodingAttr>();
auto encoding =
tensorType.getEncoding().dyn_cast_or_null<IREE::Encoding::EncodingAttr>();
if (!encoding) {
return rewriter.notifyMatchFailure(op, "no encoding attribute");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/MathExtras.h"
Expand All @@ -30,7 +29,7 @@

namespace mlir::iree_compiler {

using namespace IREE::LinalgExt;
using namespace IREE::Encoding;
using IREE::HAL::ExecutableTargetAttr;

// Enumerate tile sizes to choose from when no specific architecture is
Expand Down Expand Up @@ -432,7 +431,6 @@ struct CPUMaterializeEncodingPass
: targetAttr(attr) {}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<arith::ArithDialect, tensor::TensorDialect,
IREE::LinalgExt::IREELinalgExtDialect,
IREE::Codegen::IREECodegenDialect>();
}
void runOnOperation() override;
Expand Down Expand Up @@ -460,9 +458,8 @@ struct CPUMaterializeUpperBoundTileSizePass
FailureOr<MaterializeEncodingInfo>
materializeEncodingForTarget(RankedTensorType tensorType,
ExecutableTargetAttr targetAttr) {
IREE::LinalgExt::EncodingAttr encoding =
tensorType.getEncoding()
.dyn_cast_or_null<IREE::LinalgExt::EncodingAttr>();
IREE::Encoding::EncodingAttr encoding =
tensorType.getEncoding().dyn_cast_or_null<IREE::Encoding::EncodingAttr>();
if (!encoding) {
return failure();
}
Expand Down
6 changes: 3 additions & 3 deletions compiler/src/iree/compiler/Codegen/Common/CPU/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@
namespace mlir::iree_compiler {

/// Convert encoding-specific operations based on target attributes. Examples:
/// linalg_ext.set_encoding -> tensor.pack
/// linalg_ext.unset_encoding -> tensor.unpack
/// encoding.set_encoding -> tensor.pack
/// encoding.unset_encoding -> tensor.unpack
/// linalg.matmul -> linalg.mmt4d
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createCPUMaterializeEncodingPass(
IREE::HAL::ExecutableTargetAttr targetAttr = nullptr);

/// Like createLLVMCPUMaterializeEncodingPass, but specifically for
/// linalg_ext.upper_bound_tile_size, converting it to constants.
/// encoding.upper_bound_tile_size, converting it to constants.
///
/// Unlike createLLVMCPUMaterializeEncodingPass, this does not require the
/// op to have a specific HAL target attribute. Instead, this will iterate over
Expand Down
Loading

0 comments on commit 792f14d

Please sign in to comment.