diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodings.cpp b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodings.cpp index c7517d8bca1a..a7c1f00fe1d2 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodings.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodings.cpp @@ -8,6 +8,7 @@ #include "iree/compiler/Codegen/Common/EncodingUtils.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h" +#include "iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h" #include "iree/compiler/Codegen/Utils/Utils.h" #include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h" #include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h" @@ -481,7 +482,8 @@ materializeFuncOpEncodings(FunctionOpInterface funcOp, // 3. Heuristics for cache-friendly dispatch tiling can get complex on CPU, // so it is nice that they have fewer narrow cases to consider. MaterializeEncodingTypeConverter typeConverter( - materializeEncodingForTarget, targetAttr, /*transposeNarrowN=*/true); + materializeEncodingForTarget, targetAttr, /*transposeNarrowN=*/true, + /*layoutAttr=*/{}); MaterializeEncodingConversionTarget target(*ctx); auto materializeEncodingValueFn = getMaterializeEncodingValueFn(targetAttr); populateMaterializeEncodingIntoPackUnPackPatterns( diff --git a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp index fd75e74a987e..0464fab4e4ad 100644 --- a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp @@ -89,9 +89,10 @@ static RankedTensorType transposeIfNarrowNResult(RankedTensorType tensorType) { MaterializeEncodingTypeConverter::MaterializeEncodingTypeConverter( MaterializeEncodingFn materializeEncodingFn, - IREE::HAL::ExecutableTargetAttr targetAttr, bool transposeNarrowN) + IREE::HAL::ExecutableTargetAttr targetAttr, bool transposeNarrowN, + IREE::Codegen::LayoutAttrInterface layoutAttr) : materializeEncodingFn(materializeEncodingFn), targetAttr(targetAttr), - transposeNarrowN(transposeNarrowN) { + transposeNarrowN(transposeNarrowN), layoutAttr(layoutAttr) { addConversion([](IntegerType intType) { return intType; }); addConversion([](IndexType indexType) { return indexType; }); addConversion([](FloatType floatType) { return floatType; }); diff --git a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h index 7077fb6a05f1..a88c36803b39 100644 --- a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h +++ b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h @@ -7,7 +7,8 @@ #ifndef IREE_COMPILER_SRC_IREE_COMPILER_CODEGEN_COMMON_ENCODINGUTILS_H_ #define IREE_COMPILER_SRC_IREE_COMPILER_CODEGEN_COMMON_ENCODINGUTILS_H_ -#include "iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h" +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h" +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h" #include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h" #include "iree/compiler/Dialect/HAL/IR/HALTypes.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -34,9 +35,13 @@ using MaterializeEncodingValueFn = /// TypeConverter to use for materializing the encoding. class MaterializeEncodingTypeConverter : public TypeConverter { public: - MaterializeEncodingTypeConverter(MaterializeEncodingFn fn, - IREE::HAL::ExecutableTargetAttr targetAttr, - bool transposeNarrowN); + MaterializeEncodingTypeConverter( + MaterializeEncodingFn fn, IREE::HAL::ExecutableTargetAttr targetAttr, + bool transposeNarrowN, IREE::Codegen::LayoutAttrInterface layoutAttr); + + const IREE::Codegen::LayoutAttrInterface &getLayoutAttr() const { + return layoutAttr; + } const MaterializeEncodingFn &getMaterializeEncodingFn() const { return materializeEncodingFn; @@ -46,6 +51,9 @@ class MaterializeEncodingTypeConverter : public TypeConverter { FailureOr getEncodingInfo(RankedTensorType type) const { + if (layoutAttr) { + return layoutAttr.getEncodingInfo(type); + } return materializeEncodingFn(type, targetAttr); } @@ -55,6 +63,13 @@ class MaterializeEncodingTypeConverter : public TypeConverter { const MaterializeEncodingFn materializeEncodingFn; const IREE::HAL::ExecutableTargetAttr targetAttr; bool transposeNarrowN = false; + // The `layoutAttr` implements the logic of encoding materialization. It has + // a higher priority when it is present. + // TODO(hanchung): Move the logic that takes `targetAttr` and + // `transposeNarrowN` into account to their own attribute implementation. It + // is in a transition state, so we have two paths atm. We're incrementally + // moving the logic to attributes. + const IREE::Codegen::LayoutAttrInterface layoutAttr; }; /// Conversion target to use for for materializing the encoding. @@ -86,17 +101,15 @@ class OpMaterializeEncodingPattern : public OpConversionPattern { RankedTensorType dropEncoding(RankedTensorType type); /// Utility method to convert from `set_encoding` op to `pack` operation. -/// For now this takes a `paddingValue` as input. The source is also taken -/// as input so that these could be used with `OpConversionPatterns`. -FailureOr lowerSetEncodingOpToPackOp( +/// NOTE: `source` could be returned when packing is not needed. +FailureOr lowerSetEncodingOpToPackOp( RewriterBase &rewriter, IREE::Encoding::SetEncodingOp encodingOp, Value source, const MaterializeEncodingTypeConverter &typeConverter, MaterializeEncodingValueFn materializeEncodingValueFn); /// Utility method to convert from `unset_encoding` op to `unpack` operation. -/// The source is taken as input so that these could be used with -/// `OpConversionPatterns`. -FailureOr lowerUnsetEncodingToUnpackOp( +/// NOTE: `packedValue` could be returned when unpacking is not needed. +FailureOr lowerUnsetEncodingToUnpackOp( RewriterBase &rewriter, IREE::Encoding::UnsetEncodingOp encodingOp, Value packedValue, const MaterializeEncodingTypeConverter &typeConverter, MaterializeEncodingValueFn materializeEncodingValueFn); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp index 6debc2a8ffbc..d85686c35a51 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp @@ -315,10 +315,10 @@ struct GPUSetEncodingOpLoweringConversion ConversionPatternRewriter &rewriter) const override { auto converter = static_cast( getTypeConverter()); - auto packOp = lowerSetEncodingOpToPackOp(rewriter, encodingOp, - adaptor.getSource(), *converter, - this->materializeEncodingValueFn); - if (failed(packOp)) { + auto packedValue = lowerSetEncodingOpToPackOp( + rewriter, encodingOp, adaptor.getSource(), *converter, + this->materializeEncodingValueFn); + if (failed(packedValue)) { Type targetType = getTypeConverter()->convertType(encodingOp.getResultType()); Value result = rewriter.createOrFold( @@ -334,7 +334,7 @@ struct GPUSetEncodingOpLoweringConversion "unhandled result encoding"); } if (!maybeEncodingInfo->swizzle) { - rewriter.replaceOp(encodingOp, packOp->getResult()); + rewriter.replaceOp(encodingOp, packedValue.value()); return success(); } @@ -343,7 +343,9 @@ struct GPUSetEncodingOpLoweringConversion // Create expand_shape op to tile the innermost two dimensions. int origRank = encodingOp.getSourceType().getRank(); SmallVector expandShapeShape( - packOp->getDestType().getShape().take_front(origRank)); + cast(packedValue->getType()) + .getShape() + .take_front(origRank)); expandShapeShape.append( getExpandedTileShape(maybeEncodingInfo->swizzle->expandShape)); RankedTensorType expandShapeType = @@ -352,7 +354,7 @@ struct GPUSetEncodingOpLoweringConversion SmallVector reassociation = getReassociationIndices( origRank, maybeEncodingInfo->swizzle->expandShape); auto expandShapeOp = rewriter.create( - loc, expandShapeType, packOp->getResult(), reassociation); + loc, expandShapeType, packedValue.value(), reassociation); SmallVector transposePerm = llvm::to_vector(llvm::seq(0, origRank)); @@ -433,10 +435,10 @@ struct GPUUnsetEncodingOpLoweringConversion loc, unpackSrcType, transposeOp->getResult(0), reassociation); } - auto unPackOp = lowerUnsetEncodingToUnpackOp( + auto unpackedValue = lowerUnsetEncodingToUnpackOp( rewriter, unsetEncodingOp, unpackSrc, *converter, this->materializeEncodingValueFn); - if (failed(unPackOp)) { + if (failed(unpackedValue)) { Type targetType = getTypeConverter()->convertType(unsetEncodingOp.getResultType()); Value result = rewriter.createOrFold(loc, targetType, @@ -444,7 +446,7 @@ struct GPUUnsetEncodingOpLoweringConversion rewriter.replaceOp(unsetEncodingOp, result); return success(); } - rewriter.replaceOp(unsetEncodingOp, unPackOp->getResult()); + rewriter.replaceOp(unsetEncodingOp, unpackedValue.value()); return success(); } }; @@ -559,7 +561,8 @@ materializeFuncOpEncodings(FunctionOpInterface funcOp, // 3. Heuristics for cache-friendly dispatch tiling are internal to the GPU // runtime, so we don't need a simplification at that level either. MaterializeEncodingTypeConverter typeConverter( - materializeEncodingForTarget, targetAttr, /*transposeNarrowN=*/false); + materializeEncodingForTarget, targetAttr, /*transposeNarrowN=*/false, + /*layoutAttr=*/{}); MaterializeEncodingConversionTarget target(*ctx); MaterializeEncodingValueFn materializeEncodingValueFn = [](RankedTensorType, OpBuilder, diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoNop.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoNop.cpp index 32eb822c189d..f9e1fc53bc94 100644 --- a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoNop.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoNop.cpp @@ -7,6 +7,8 @@ #include "iree/compiler/Codegen/Common/EncodingUtils.h" #include "iree/compiler/Codegen/Common/PassUtils.h" #include "iree/compiler/Codegen/Common/Passes.h" +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h" #include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h" #include "iree/compiler/Dialect/HAL/IR/HALTypes.h" #include "mlir/Dialect/MemRef/Transforms/Transforms.h" @@ -28,7 +30,8 @@ namespace { struct MaterializeEncodingIntoNopPass final : impl::MaterializeEncodingIntoNopPassBase { void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry.insert(); } void runOnOperation() override { @@ -47,7 +50,8 @@ struct MaterializeEncodingIntoNopPass final RewritePatternSet materializeEncodingPattern(context); MaterializeEncodingTypeConverter typeConverter( materializeEncodingFn, IREE::HAL::ExecutableTargetAttr(), - /*transposeNarrowN=*/false); + /*transposeNarrowN=*/false, + IREE::Codegen::EncodingNopLayoutAttr::get(context)); MaterializeEncodingConversionTarget target(*context); populateMaterializeEncodingIntoPackUnPackPatterns( materializeEncodingPattern, typeConverter, materializeEncodingValueFn); diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp index fc3bb45c8be6..4d36b53af00c 100644 --- a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp @@ -10,12 +10,14 @@ #include "iree/compiler/Codegen/Common/EncodingUtils.h" #include "iree/compiler/Codegen/Common/Passes.h" +#include "iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h" #include "iree/compiler/Codegen/Utils/Utils.h" #include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h" #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "iree/compiler/Dialect/HAL/IR/HALTypes.h" #include "iree/compiler/Dialect/Util/IR/UtilOps.h" #include "llvm/ADT/SmallVectorExtras.h" +#include "llvm/Support/LogicalResult.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -119,12 +121,7 @@ static void transposeInPlace(MaterializeEncodingInfo &info) { // to `pack` and `unpack` operations respectively. //===---------------------------------------------------------------------===// -/// TODO(hanchung): Move the implementation to EncodingUtils.cpp. It is not -/// moved because it needs some cleanup for this file. E.g., `getPaddingValue` -/// is no longer needed. Ideally we should move CPU specific patterns (e.g., -/// lowerContractionOpWithEncoding, etc) to the CPUMaterializeEncoding file; -/// move general patterns to EncodingUtils, and retire this file. -FailureOr lowerSetEncodingOpToPackOp( +FailureOr lowerSetEncodingOpToPackOp( RewriterBase &rewriter, IREE::Encoding::SetEncodingOp encodingOp, Value source, const MaterializeEncodingTypeConverter &typeConverter, MaterializeEncodingValueFn materializeEncodingValueFn) { @@ -135,6 +132,11 @@ FailureOr lowerSetEncodingOpToPackOp( return rewriter.notifyMatchFailure(encodingOp, "unhandled result encoding"); } + // Shortcut to avoid creating new operations. + if (IREE::Codegen::isIdentityLayout(encodingInfo.value())) { + return source; + } + auto encoding = IREE::Encoding::getEncodingAttr(resultType); if (!encoding) { return failure(); @@ -160,14 +162,14 @@ FailureOr lowerSetEncodingOpToPackOp( encodingInfo->outerDimsPerm); auto emptyOp = rewriter.create(loc, resultDims, resultType.getElementType()); - return rewriter.create( - loc, source, emptyOp, encodingInfo->innerDimsPos, *innerTileSizesOfr, - paddingValue, encodingInfo->outerDimsPerm); + return rewriter + .create(loc, source, emptyOp, encodingInfo->innerDimsPos, + *innerTileSizesOfr, paddingValue, + encodingInfo->outerDimsPerm) + .getResult(); } -/// TODO(hanchung): Move the implementation to EncodingUtils.cpp. See the reason -/// in the implementation comment of lowerSetEncodingToPackOp method. -FailureOr lowerUnsetEncodingToUnpackOp( +FailureOr lowerUnsetEncodingToUnpackOp( RewriterBase &rewriter, IREE::Encoding::UnsetEncodingOp encodingOp, Value packedValue, const MaterializeEncodingTypeConverter &typeConverter, MaterializeEncodingValueFn materializeEncodingValueFn) { @@ -177,6 +179,12 @@ FailureOr lowerUnsetEncodingToUnpackOp( if (failed(encodingInfo)) { return rewriter.notifyMatchFailure(encodingOp, "unhandled source encoding"); } + + // Shortcut to avoid creating new operations. + if (IREE::Codegen::isIdentityLayout(encodingInfo.value())) { + return packedValue; + } + auto encoding = IREE::Encoding::getEncodingAttr(sourceType); if (typeConverter.getTransposeNarrowN() && isNarrowNResult(encoding)) { transposeInPlace(*encodingInfo); @@ -194,9 +202,11 @@ FailureOr lowerUnsetEncodingToUnpackOp( return rewriter.notifyMatchFailure( encodingOp, "failed to generate runtime tile size query"); } - return rewriter.create( - loc, packedValue, emptyOp, encodingInfo->innerDimsPos, *innerTileSizesOfr, - encodingInfo->outerDimsPerm); + return rewriter + .create(loc, packedValue, emptyOp, + encodingInfo->innerDimsPos, *innerTileSizesOfr, + encodingInfo->outerDimsPerm) + .getResult(); } /// Utility method to convert `tensor.empty` with encoding to a `tensor.empty` @@ -609,7 +619,7 @@ struct SetEncodingOpToPackOpConversion rewriter.replaceOp(encodingOp, result); return success(); } - rewriter.replaceOp(encodingOp, packOp->getResult()); + rewriter.replaceOp(encodingOp, packOp.value()); return success(); } }; @@ -625,10 +635,10 @@ struct UnsetEncodingOpToUnPackOpConversion ConversionPatternRewriter &rewriter) const override { auto converter = static_cast( this->getTypeConverter()); - auto unpackOp = lowerUnsetEncodingToUnpackOp( + auto unpackedValue = lowerUnsetEncodingToUnpackOp( rewriter, encodingOp, adaptor.getSource(), *converter, this->materializeEncodingValueFn); - if (failed(unpackOp)) { + if (failed(unpackedValue)) { Type targetType = getTypeConverter()->convertType(encodingOp.getResultType()); Value result = rewriter.createOrFold( @@ -636,7 +646,7 @@ struct UnsetEncodingOpToUnPackOpConversion rewriter.replaceOp(encodingOp, result); return success(); } - rewriter.replaceOp(encodingOp, unpackOp->getResult()); + rewriter.replaceOp(encodingOp, unpackedValue.value()); return success(); } }; @@ -734,6 +744,18 @@ class MaterializeContractionOp auto converter = static_cast( this->getTypeConverter()); + + if (auto layoutAttr = converter->getLayoutAttr()) { + SmallVector convertedResTypes; + for (auto init : op.getDpsInits()) { + convertedResTypes.push_back(converter->convertType(init.getType())); + } + Operation *newOp = + layoutAttr.lowerOp(rewriter, op, convertedResTypes, operands); + rewriter.replaceOp(op, newOp->getResults()); + return success(); + } + // TODO(hanchung): This is a transition state for moving the implementation // details to backend attributes. We won't need the function type argument // after all the backends that support encodings implement the attribute. diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp index b5403a2b1f88..dde48b9b1ed7 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp @@ -8,11 +8,13 @@ #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h" +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h" #include "llvm/ADT/TypeSwitch.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Transform/IR/TransformOps.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/StorageUniquerSupport.h" @@ -460,6 +462,21 @@ int64_t WorkgroupMappingAttr::getRelativeIndex() const { return getMappingId(); } +//===---------------------------------------------------------------------===// +// iree_codegen.encoding_layout +//===---------------------------------------------------------------------===// + +MaterializeEncodingInfo +EncodingNopLayoutAttr::getEncodingInfo(RankedTensorType type) const { + return MaterializeEncodingInfo{}; +} + +Operation *EncodingNopLayoutAttr::lowerOp(OpBuilder &b, Operation *op, + TypeRange convertedResTypes, + ValueRange convertedOperands) const { + return clone(b, op, convertedResTypes, convertedOperands); +} + //===----------------------------------------------------------------------===// // Initialize attributes //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td index 3086c09b2069..26b37dd07e24 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td @@ -418,4 +418,24 @@ def IREECodegen_ExportConfig : AttrDef let genVerifyDecl = 1; } +//===---------------------------------------------------------------------===// +// iree_codegen.encoding_layout +//===---------------------------------------------------------------------===// + +def IREECodegen_EncodingNopLayoutAttr : + AttrDef + ]> { + let mnemonic = "encoding_nop_layout"; + let summary = "An attribute with implementation that treats encoding as nop."; + let description = [{ + An attribute that implements the interface methods that discards the + encodings. It can be a default attribute when a backend does not implement + encoding details. + }]; +} + #endif // IREE_COMPILER_CODEGEN_DIALECT_IREECODEGENATTRS diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h index b8a026db441c..c35058fd46ba 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h @@ -7,6 +7,7 @@ #ifndef IREE_COMPILER_CODEGEN_DIALECT_CODEGEN_IREECODEGENINTERFACES_H_ #define IREE_COMPILER_CODEGEN_DIALECT_CODEGEN_IREECODEGENINTERFACES_H_ +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.td index 2f8dffb39792..36d5d3db31ec 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.td @@ -116,4 +116,44 @@ def IREECodegen_LoweringConfigAttrInterface : ]; } +def IREECodegen_LayoutAttrInterface : + AttrInterface<"LayoutAttrInterface"> { + let cppNamespace = "::mlir::iree_compiler::IREE::Codegen"; + let description = [{ + An interface that collects a set of methods for encoding materialization. + }]; + + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Returns the layout of materialized encoding for a tensor type. + }], + /*retTy=*/"::mlir::iree_compiler::IREE::Codegen::MaterializeEncodingInfo", + /*methodName=*/"getEncodingInfo", + /*args=*/(ins "::mlir::RankedTensorType":$type), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert(false && "unimplemented interface method"); + return MaterializeEncodingInfo{}; + }] + >, + InterfaceMethod< + /*desc=*/[{ + Returns the layout of materialized encoding for a tensor type. + }], + /*retTy=*/"::mlir::Operation *", + /*methodName=*/"lowerOp", + /*args=*/(ins "::mlir::OpBuilder &":$b, + "::mlir::Operation *":$op, + "::mlir::TypeRange":$convertedResTypes, + "::mlir::ValueRange":$convertedOperands), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert(false && "unimplemented interface method"); + return nullptr; + }] + > + ]; +} + #endif // IREE_COMPILER_CODEGEN_DIALECT_CODEGEN_IREECODEGENINTERFACES diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.cpp index 4a12f7013417..7976b7ed8ba6 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.cpp @@ -244,6 +244,14 @@ deserializeEncodingInfo(DictionaryAttr attr) { return info; } +bool isIdentityLayout(const MaterializeEncodingInfo &info) { + // It is not an identity layout if swizzle is present. The swizzle is an + // optional variable. User should not set the field when they do not need + // swizzle. + return info.innerDimsPos.empty() && info.innerTileSizes.empty() && + info.outerDimsPerm.empty() && !info.swizzle; +} + SmallVector getExpandedTileShape(const TileSwizzle::ExpandShapeType &expandShape) { SmallVector result; diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h index b1997c1b91fd..8498a95e11a3 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h @@ -57,6 +57,10 @@ DictionaryAttr serializeEncodingInfo(MLIRContext *ctx, std::optional deserializeEncodingInfo(DictionaryAttr attr); +/// Returns true if the `info` denotes an identity layout, i.e., there is no +/// relayout requirement. +bool isIdentityLayout(const MaterializeEncodingInfo &info); + /// Concatenates the vectors. SmallVector getExpandedTileShape(const TileSwizzle::ExpandShapeType &expandShape); diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/unittests/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/unittests/BUILD.bazel index abe11581e12c..8159dea713f1 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/unittests/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/unittests/BUILD.bazel @@ -16,6 +16,7 @@ iree_compiler_cc_test( testonly = True, srcs = ["UtilsTest.cpp"], deps = [ + "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect", "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils", "//compiler/src/iree/testing:gtest_main", "@com_google_googletest//:gtest", diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/unittests/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/unittests/CMakeLists.txt index bf20bd2bad9d..6624ee4fcee2 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/unittests/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/unittests/CMakeLists.txt @@ -19,6 +19,7 @@ iree_cc_test( MLIRIR gmock gtest + iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect iree::compiler::Codegen::Dialect::Codegen::Utils iree::testing::gtest_main ) diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/unittests/UtilsTest.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/unittests/UtilsTest.cpp index 82f482761da7..418626900ec9 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/unittests/UtilsTest.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/unittests/UtilsTest.cpp @@ -7,6 +7,7 @@ #include #include +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h" #include "iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Attributes.h" @@ -186,5 +187,12 @@ TEST(MaterializeEncodingInfo, Deserialization) { EXPECT_TRUE(deserializeEncodingInfo(b.getDictionaryAttr(items)).has_value()); } +TEST(MaterializeEncodingInfo, IdentityLayout) { + MaterializeEncodingInfo info; + EXPECT_TRUE(isIdentityLayout(info)); + info.swizzle = TileSwizzle(); + EXPECT_FALSE(isIdentityLayout(info)); +} + } // namespace } // namespace mlir::iree_compiler::IREE::Codegen diff --git a/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel b/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel index 3797824fa122..d85310e8dfe4 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel +++ b/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel @@ -78,6 +78,7 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Codegen/Common", "//compiler/src/iree/compiler/Codegen/Common/CPU:CommonCPUPasses", "//compiler/src/iree/compiler/Codegen/Common/GPU:CommonGPUPasses", + "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect", "//compiler/src/iree/compiler/Dialect/Encoding/IR", "//compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow", "//compiler/src/iree/compiler/Dialect/Flow/IR", diff --git a/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt b/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt index 70bd927bfc7e..9ca16eed433d 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt +++ b/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt @@ -93,6 +93,7 @@ iree_cc_library( iree::compiler::Codegen::Common iree::compiler::Codegen::Common::CPU::CommonCPUPasses iree::compiler::Codegen::Common::GPU::CommonGPUPasses + iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect iree::compiler::Dialect::Encoding::IR iree::compiler::Dialect::Flow::Conversion::TensorToFlow iree::compiler::Dialect::Flow::IR diff --git a/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp b/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp index 4ce2d92d5748..adcc12977bad 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp @@ -7,6 +7,7 @@ #include "iree/compiler/Codegen/Common/CPU/Passes.h" #include "iree/compiler/Codegen/Common/GPU/Passes.h" #include "iree/compiler/Codegen/Common/Passes.h" +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h" #include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h" #include "iree/compiler/Dialect/HAL/IR/HALDialect.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" @@ -44,7 +45,8 @@ class MaterializeHomogeneousEncodingsPass MaterializeHomogeneousEncodingsPass> { public: void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry.insert(); } void addNopPipeline(OpPassManager &passManager) {