diff --git a/flang/include/flang/Optimizer/CodeGen/TypeConverter.h b/flang/include/flang/Optimizer/CodeGen/TypeConverter.h index 4131eb53f07625..f42c40eb68902b 100644 --- a/flang/include/flang/Optimizer/CodeGen/TypeConverter.h +++ b/flang/include/flang/Optimizer/CodeGen/TypeConverter.h @@ -49,20 +49,20 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter { // i32 is used here because LLVM wants i32 constants when indexing into struct // types. Indexing into other aggregate types is more flexible. - mlir::Type offsetType(); + mlir::Type offsetType() const; // i64 can be used to index into aggregates like arrays - mlir::Type indexType(); + mlir::Type indexType() const; // fir.type --> llvm<"%name = { ty... }"> std::optional convertRecordType(fir::RecordType derived, llvm::SmallVectorImpl &results, - llvm::ArrayRef callStack); + llvm::ArrayRef callStack) const; // Is an extended descriptor needed given the element type of a fir.box type ? // Extended descriptors are required for derived types. - bool requiresExtendedDesc(mlir::Type boxElementType); + bool requiresExtendedDesc(mlir::Type boxElementType) const; // Magic value to indicate we do not know the rank of an entity, either // because it is assumed rank or because we have not determined it yet. @@ -70,35 +70,33 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter { // This corresponds to the descriptor as defined in ISO_Fortran_binding.h and // the addendum defined in descriptor.h. - mlir::Type convertBoxType(BaseBoxType box, int rank = unknownRank()); + mlir::Type convertBoxType(BaseBoxType box, int rank = unknownRank()) const; /// Convert fir.box type to the corresponding llvm struct type instead of a /// pointer to this struct type. - mlir::Type convertBoxTypeAsStruct(BaseBoxType box); + mlir::Type convertBoxTypeAsStruct(BaseBoxType box) const; // fir.boxproc --> llvm<"{ any*, i8* }"> - mlir::Type convertBoxProcType(BoxProcType boxproc); + mlir::Type convertBoxProcType(BoxProcType boxproc) const; - unsigned characterBitsize(fir::CharacterType charTy); + unsigned characterBitsize(fir::CharacterType charTy) const; // fir.char --> llvm<"ix"> where ix is scaled by kind mapping // fir.char --> llvm.array - mlir::Type convertCharType(fir::CharacterType charTy); + mlir::Type convertCharType(fir::CharacterType charTy) const; // Use the target specifics to figure out how to map complex to LLVM IR. The // use of complex values in function signatures is handled before conversion // to LLVM IR dialect here. // // fir.complex | std.complex --> llvm<"{t,t}"> - template - mlir::Type convertComplexType(C cmplx) { + template mlir::Type convertComplexType(C cmplx) const { LLVM_DEBUG(llvm::dbgs() << "type convert: " << cmplx << '\n'); auto eleTy = cmplx.getElementType(); return convertType(specifics->complexMemoryType(eleTy)); } - template - mlir::Type convertPointerLike(A &ty) { + template mlir::Type convertPointerLike(A &ty) const { mlir::Type eleTy = ty.getEleTy(); // A sequence type is a special case. A sequence of runtime size on its // interior dimensions lowers to a memory reference. In that case, we @@ -126,27 +124,27 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter { // convert a front-end kind value to either a std or LLVM IR dialect type // fir.real --> llvm.anyfloat where anyfloat is a kind mapping - mlir::Type convertRealType(fir::KindTy kind); + mlir::Type convertRealType(fir::KindTy kind) const; // fir.array --> llvm<"[...[c x any]]"> - mlir::Type convertSequenceType(SequenceType seq); + mlir::Type convertSequenceType(SequenceType seq) const; // fir.tdesc --> llvm<"i8*"> // TODO: For now use a void*, however pointer identity is not sufficient for // the f18 object v. class distinction (F2003). - mlir::Type convertTypeDescType(mlir::MLIRContext *ctx); + mlir::Type convertTypeDescType(mlir::MLIRContext *ctx) const; - KindMapping &getKindMap() { return kindMapping; } + const KindMapping &getKindMap() const { return kindMapping; } // Relay TBAA tag attachment to TBAABuilder. void attachTBAATag(mlir::LLVM::AliasAnalysisOpInterface op, mlir::Type baseFIRType, mlir::Type accessFIRType, - mlir::LLVM::GEPOp gep); + mlir::LLVM::GEPOp gep) const; private: KindMapping kindMapping; std::unique_ptr specifics; - TBAABuilder tbaaBuilder; + std::unique_ptr tbaaBuilder; }; } // namespace fir diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index 596458d37d2dcd..0fbee616ac9a51 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -117,7 +117,7 @@ namespace { template class FIROpConversion : public mlir::ConvertOpToLLVMPattern { public: - explicit FIROpConversion(fir::LLVMTypeConverter &lowering, + explicit FIROpConversion(const fir::LLVMTypeConverter &lowering, const fir::FIRToLLVMPassOptions &options) : mlir::ConvertOpToLLVMPattern(lowering), options(options) {} @@ -359,8 +359,9 @@ class FIROpConversion : public mlir::ConvertOpToLLVMPattern { return al; } - fir::LLVMTypeConverter &lowerTy() const { - return *static_cast(this->getTypeConverter()); + const fir::LLVMTypeConverter &lowerTy() const { + return *static_cast( + this->getTypeConverter()); } void attachTBAATag(mlir::LLVM::AliasAnalysisOpInterface op, @@ -3191,8 +3192,8 @@ struct SelectCaseOpConversion : public FIROpConversion { }; template -static void selectMatchAndRewrite(fir::LLVMTypeConverter &lowering, OP select, - typename OP::Adaptor adaptor, +static void selectMatchAndRewrite(const fir::LLVMTypeConverter &lowering, + OP select, typename OP::Adaptor adaptor, mlir::ConversionPatternRewriter &rewriter) { unsigned conds = select.getNumConditions(); auto cases = select.getCases().getValue(); @@ -3461,7 +3462,7 @@ template static mlir::LLVM::InsertValueOp complexSum(OPTY sumop, mlir::ValueRange opnds, mlir::ConversionPatternRewriter &rewriter, - fir::LLVMTypeConverter &lowering) { + const fir::LLVMTypeConverter &lowering) { mlir::Value a = opnds[0]; mlir::Value b = opnds[1]; auto loc = sumop.getLoc(); @@ -3610,7 +3611,7 @@ struct NegcOpConversion : public FIROpConversion { /// These operations are normally dead after the pre-codegen pass. template struct MustBeDeadConversion : public FIROpConversion { - explicit MustBeDeadConversion(fir::LLVMTypeConverter &lowering, + explicit MustBeDeadConversion(const fir::LLVMTypeConverter &lowering, const fir::FIRToLLVMPassOptions &options) : FIROpConversion(lowering, options) {} using OpAdaptor = typename FromOp::Adaptor; diff --git a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp index 8de2dbbca3f806..fd5f0c7135fea2 100644 --- a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp +++ b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp @@ -37,7 +37,8 @@ LLVMTypeConverter::LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA) specifics(CodeGenSpecifics::get(module.getContext(), getTargetTriple(module), getKindMapping(module))), - tbaaBuilder(module->getContext(), applyTBAA) { + tbaaBuilder( + std::make_unique(module->getContext(), applyTBAA)) { LLVM_DEBUG(llvm::dbgs() << "FIR type converter\n"); // Each conversion should return a value of type mlir::Type. @@ -155,20 +156,19 @@ LLVMTypeConverter::LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA) // i32 is used here because LLVM wants i32 constants when indexing into struct // types. Indexing into other aggregate types is more flexible. -mlir::Type LLVMTypeConverter::offsetType() { +mlir::Type LLVMTypeConverter::offsetType() const { return mlir::IntegerType::get(&getContext(), 32); } // i64 can be used to index into aggregates like arrays -mlir::Type LLVMTypeConverter::indexType() { +mlir::Type LLVMTypeConverter::indexType() const { return mlir::IntegerType::get(&getContext(), 64); } // fir.type --> llvm<"%name = { ty... }"> -std::optional -LLVMTypeConverter::convertRecordType(fir::RecordType derived, - llvm::SmallVectorImpl &results, - llvm::ArrayRef callStack) { +std::optional LLVMTypeConverter::convertRecordType( + fir::RecordType derived, llvm::SmallVectorImpl &results, + llvm::ArrayRef callStack) const { auto name = derived.getName(); auto st = mlir::LLVM::LLVMStructType::getIdentified(&getContext(), name); if (llvm::count(callStack, derived) > 1) { @@ -192,14 +192,14 @@ LLVMTypeConverter::convertRecordType(fir::RecordType derived, // Is an extended descriptor needed given the element type of a fir.box type ? // Extended descriptors are required for derived types. -bool LLVMTypeConverter::requiresExtendedDesc(mlir::Type boxElementType) { +bool LLVMTypeConverter::requiresExtendedDesc(mlir::Type boxElementType) const { auto eleTy = fir::unwrapSequenceType(boxElementType); return eleTy.isa(); } // This corresponds to the descriptor as defined in ISO_Fortran_binding.h and // the addendum defined in descriptor.h. -mlir::Type LLVMTypeConverter::convertBoxType(BaseBoxType box, int rank) { +mlir::Type LLVMTypeConverter::convertBoxType(BaseBoxType box, int rank) const { // (base_addr*, elem_len, version, rank, type, attribute, f18Addendum, [dim] llvm::SmallVector dataDescFields; mlir::Type ele = box.getEleTy(); @@ -269,14 +269,14 @@ mlir::Type LLVMTypeConverter::convertBoxType(BaseBoxType box, int rank) { /// Convert fir.box type to the corresponding llvm struct type instead of a /// pointer to this struct type. -mlir::Type LLVMTypeConverter::convertBoxTypeAsStruct(BaseBoxType box) { +mlir::Type LLVMTypeConverter::convertBoxTypeAsStruct(BaseBoxType box) const { return convertBoxType(box) .cast() .getElementType(); } // fir.boxproc --> llvm<"{ any*, i8* }"> -mlir::Type LLVMTypeConverter::convertBoxProcType(BoxProcType boxproc) { +mlir::Type LLVMTypeConverter::convertBoxProcType(BoxProcType boxproc) const { auto funcTy = convertType(boxproc.getEleTy()); auto i8PtrTy = mlir::LLVM::LLVMPointerType::get( mlir::IntegerType::get(&getContext(), 8)); @@ -285,13 +285,13 @@ mlir::Type LLVMTypeConverter::convertBoxProcType(BoxProcType boxproc) { /*isPacked=*/false); } -unsigned LLVMTypeConverter::characterBitsize(fir::CharacterType charTy) { +unsigned LLVMTypeConverter::characterBitsize(fir::CharacterType charTy) const { return kindMapping.getCharacterBitsize(charTy.getFKind()); } // fir.char --> llvm<"ix"> where ix is scaled by kind mapping // fir.char --> llvm.array -mlir::Type LLVMTypeConverter::convertCharType(fir::CharacterType charTy) { +mlir::Type LLVMTypeConverter::convertCharType(fir::CharacterType charTy) const { auto iTy = mlir::IntegerType::get(&getContext(), characterBitsize(charTy)); if (charTy.getLen() == fir::CharacterType::unknownLen()) return iTy; @@ -300,13 +300,13 @@ mlir::Type LLVMTypeConverter::convertCharType(fir::CharacterType charTy) { // convert a front-end kind value to either a std or LLVM IR dialect type // fir.real --> llvm.anyfloat where anyfloat is a kind mapping -mlir::Type LLVMTypeConverter::convertRealType(fir::KindTy kind) { +mlir::Type LLVMTypeConverter::convertRealType(fir::KindTy kind) const { return fir::fromRealTypeID(&getContext(), kindMapping.getRealTypeID(kind), kind); } // fir.array --> llvm<"[...[c x any]]"> -mlir::Type LLVMTypeConverter::convertSequenceType(SequenceType seq) { +mlir::Type LLVMTypeConverter::convertSequenceType(SequenceType seq) const { auto baseTy = convertType(seq.getEleTy()); if (characterWithDynamicLen(seq.getEleTy())) return mlir::LLVM::LLVMPointerType::get(baseTy); @@ -328,7 +328,8 @@ mlir::Type LLVMTypeConverter::convertSequenceType(SequenceType seq) { // fir.tdesc --> llvm<"i8*"> // TODO: For now use a void*, however pointer identity is not sufficient for // the f18 object v. class distinction (F2003). -mlir::Type LLVMTypeConverter::convertTypeDescType(mlir::MLIRContext *ctx) { +mlir::Type +LLVMTypeConverter::convertTypeDescType(mlir::MLIRContext *ctx) const { return mlir::LLVM::LLVMPointerType::get( mlir::IntegerType::get(&getContext(), 8)); } @@ -337,8 +338,8 @@ mlir::Type LLVMTypeConverter::convertTypeDescType(mlir::MLIRContext *ctx) { void LLVMTypeConverter::attachTBAATag(mlir::LLVM::AliasAnalysisOpInterface op, mlir::Type baseFIRType, mlir::Type accessFIRType, - mlir::LLVM::GEPOp gep) { - tbaaBuilder.attachTBAATag(op, baseFIRType, accessFIRType, gep); + mlir::LLVM::GEPOp gep) const { + tbaaBuilder->attachTBAATag(op, baseFIRType, accessFIRType, gep); } } // namespace fir diff --git a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h index 28d37a91edb80d..ef8215d332c463 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h @@ -40,13 +40,14 @@ class MemRefDescriptor : public StructBuilder { /// Builds IR creating a MemRef descriptor that represents `type` and /// populates it with static shape and stride information extracted from the /// type. - static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, - MemRefType type, Value memory); - static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, - MemRefType type, Value memory, - Value alignedMemory); + static MemRefDescriptor + fromStaticShape(OpBuilder &builder, Location loc, + const LLVMTypeConverter &typeConverter, MemRefType type, + Value memory); + static MemRefDescriptor + fromStaticShape(OpBuilder &builder, Location loc, + const LLVMTypeConverter &typeConverter, MemRefType type, + Value memory, Value alignedMemory); /// Builds IR extracting the allocated pointer from the descriptor. Value allocatedPtr(OpBuilder &builder, Location loc); @@ -95,7 +96,7 @@ class MemRefDescriptor : public StructBuilder { /// \note there is no setter for this one since it is derived from alignedPtr /// and offset. Value bufferPtr(OpBuilder &builder, Location loc, - LLVMTypeConverter &converter, MemRefType type); + const LLVMTypeConverter &converter, MemRefType type); /// Builds IR populating a MemRef descriptor structure from a list of /// individual values composing that descriptor, in the following order: @@ -106,7 +107,7 @@ class MemRefDescriptor : public StructBuilder { /// - shapes; /// where is the MemRef rank as provided in `type`. static Value pack(OpBuilder &builder, Location loc, - LLVMTypeConverter &converter, MemRefType type, + const LLVMTypeConverter &converter, MemRefType type, ValueRange values); /// Builds IR extracting individual elements of a MemRef descriptor structure @@ -178,7 +179,7 @@ class UnrankedMemRefDescriptor : public StructBuilder { /// - rank of the memref; /// - pointer to the memref descriptor. static Value pack(OpBuilder &builder, Location loc, - LLVMTypeConverter &converter, UnrankedMemRefType type, + const LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values); /// Builds IR extracting individual elements that compose an unranked memref @@ -195,7 +196,7 @@ class UnrankedMemRefDescriptor : public StructBuilder { /// which must have the same length as `values`, is needed to handle layouts /// where sizeof(ptr addrspace(N)) != sizeof(ptr addrspace(0)). static void computeSizes(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, + const LLVMTypeConverter &typeConverter, ArrayRef values, ArrayRef addressSpaces, SmallVectorImpl &sizes); @@ -217,11 +218,12 @@ class UnrankedMemRefDescriptor : public StructBuilder { /// Builds IR extracting the aligned pointer from the descriptor. static Value alignedPtr(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, Value memRefDescPtr, + const LLVMTypeConverter &typeConverter, + Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType); /// Builds IR inserting the aligned pointer into the descriptor. static void setAlignedPtr(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, + const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value alignedPtr); @@ -230,44 +232,45 @@ class UnrankedMemRefDescriptor : public StructBuilder { /// Returns a pointer to a convertType(index), which points to the beggining /// of a struct {index, index[rank], index[rank]}. static Value offsetBasePtr(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, + const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType); /// Builds IR extracting the offset from the descriptor. static Value offset(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, Value memRefDescPtr, - LLVM::LLVMPointerType elemPtrType); + const LLVMTypeConverter &typeConverter, + Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType); /// Builds IR inserting the offset into the descriptor. static void setOffset(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, Value memRefDescPtr, - LLVM::LLVMPointerType elemPtrType, Value offset); + const LLVMTypeConverter &typeConverter, + Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, + Value offset); /// Builds IR extracting the pointer to the first element of the size array. static Value sizeBasePtr(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, + const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType); /// Builds IR extracting the size[index] from the descriptor. static Value size(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, Value sizeBasePtr, + const LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value index); /// Builds IR inserting the size[index] into the descriptor. static void setSize(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, Value sizeBasePtr, + const LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value index, Value size); /// Builds IR extracting the pointer to the first element of the stride array. static Value strideBasePtr(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, + const LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value rank); /// Builds IR extracting the stride[index] from the descriptor. static Value stride(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, Value strideBasePtr, - Value index, Value stride); + const LLVMTypeConverter &typeConverter, + Value strideBasePtr, Value index, Value stride); /// Builds IR inserting the stride[index] into the descriptor. static void setStride(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, Value strideBasePtr, - Value index, Value stride); + const LLVMTypeConverter &typeConverter, + Value strideBasePtr, Value index, Value stride); }; } // namespace mlir diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h index 075d753ea6ed82..92f4025ffffffb 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h @@ -23,7 +23,7 @@ namespace detail { LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef targetAttrs, - LLVMTypeConverter &typeConverter, + const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter); } // namespace detail @@ -37,14 +37,14 @@ LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, class ConvertToLLVMPattern : public ConversionPattern { public: ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, - LLVMTypeConverter &typeConverter, + const LLVMTypeConverter &typeConverter, PatternBenefit benefit = 1); protected: /// Returns the LLVM dialect. LLVM::LLVMDialect &getDialect() const; - LLVMTypeConverter *getTypeConverter() const; + const LLVMTypeConverter *getTypeConverter() const; /// Gets the MLIR type wrapping the LLVM integer type whose bit width is /// defined by the used type converter. @@ -140,7 +140,7 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern { public: using OpAdaptor = typename SourceOp::Adaptor; - explicit ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter, + explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit = 1) : ConvertToLLVMPattern(SourceOp::getOperationName(), &typeConverter.getContext(), typeConverter, diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h index 79a68e875f045e..2097aa78ebd70e 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h @@ -33,7 +33,7 @@ class LLVMStructType; class LLVMTypeConverter : public TypeConverter { /// Give structFuncArgTypeConverter access to memref-specific functions. friend LogicalResult - structFuncArgTypeConverter(LLVMTypeConverter &converter, Type type, + structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl &result); public: @@ -54,20 +54,20 @@ class LLVMTypeConverter : public TypeConverter { /// is populated with argument mapping. Type convertFunctionSignature(FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv, - SignatureConversion &result); + SignatureConversion &result) const; /// Convert a non-empty list of types to be returned from a function into an /// LLVM-compatible type. In particular, if more than one value is returned, /// create an LLVM dialect structure type with elements that correspond to /// each of the types converted with `convertCallingConventionType`. Type packFunctionResults(TypeRange types, - bool useBarePointerCallConv = false); + bool useBarePointerCallConv = false) const; /// Convert a non-empty list of types of values produced by an operation into /// an LLVM-compatible type. In particular, if more than one value is /// produced, create a literal structure with elements that correspond to each /// of the LLVM-compatible types converted with `convertType`. - Type packOperationResults(TypeRange types); + Type packOperationResults(TypeRange types) const; /// Convert a type in the context of the default or bare pointer calling /// convention. Calling convention sensitive types, such as MemRefType and @@ -75,20 +75,20 @@ class LLVMTypeConverter : public TypeConverter { /// calling convention. Calling convention independent types are converted /// following the default LLVM type conversions. Type convertCallingConventionType(Type type, - bool useBarePointerCallConv = false); + bool useBarePointerCallConv = false) const; /// Promote the bare pointers in 'values' that resulted from memrefs to /// descriptors. 'stdTypes' holds the types of 'values' before the conversion /// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type). void promoteBarePtrsToDescriptors(ConversionPatternRewriter &rewriter, Location loc, ArrayRef stdTypes, - SmallVectorImpl &values); + SmallVectorImpl &values) const; /// Returns the MLIR context. - MLIRContext &getContext(); + MLIRContext &getContext() const; /// Returns the LLVM dialect. - LLVM::LLVMDialect *getDialect() { return llvmDialect; } + LLVM::LLVMDialect *getDialect() const { return llvmDialect; } const LowerToLLVMOptions &getOptions() const { return options; } @@ -105,23 +105,23 @@ class LLVMTypeConverter : public TypeConverter { /// passing. SmallVector promoteOperands(Location loc, ValueRange opOperands, ValueRange operands, OpBuilder &builder, - bool useBarePtrCallConv = false); + bool useBarePtrCallConv = false) const; /// Promote the LLVM struct representation of one MemRef descriptor to stack /// and use pointer to struct to avoid the complexity of the platform-specific /// C/C++ ABI lowering related to struct argument passing. Value promoteOneMemRefDescriptor(Location loc, Value operand, - OpBuilder &builder); + OpBuilder &builder) const; /// Converts the function type to a C-compatible format, in particular using /// pointers to memref descriptors for arguments. Also converts the return /// type to a pointer argument if it is a struct. Returns true if this /// was the case. std::pair - convertFunctionTypeCWrapper(FunctionType type); + convertFunctionTypeCWrapper(FunctionType type) const; /// Returns the data layout to use during and after conversion. - const llvm::DataLayout &getDataLayout() { return options.dataLayout; } + const llvm::DataLayout &getDataLayout() const { return options.dataLayout; } /// Returns the data layout analysis to query during conversion. const DataLayoutAnalysis *getDataLayoutAnalysis() const { @@ -130,7 +130,7 @@ class LLVMTypeConverter : public TypeConverter { /// Gets the LLVM representation of the index type. The returned type is an /// integer type with the size configured for this type converter. - Type getIndexType(); + Type getIndexType() const; /// Returns true if using opaque pointers was enabled in the lowering options. bool useOpaquePointers() const { return getOptions().useOpaquePointers; } @@ -141,25 +141,26 @@ class LLVMTypeConverter : public TypeConverter { /// pointers, as it will create an opaque pointer with the given address space /// if opaque pointers are enabled in the lowering options. LLVM::LLVMPointerType getPointerType(Type elementType, - unsigned addressSpace = 0); + unsigned addressSpace = 0) const; /// Gets the bitwidth of the index type when converted to LLVM. - unsigned getIndexTypeBitwidth() { return options.getIndexBitwidth(); } + unsigned getIndexTypeBitwidth() const { return options.getIndexBitwidth(); } /// Gets the pointer bitwidth. - unsigned getPointerBitwidth(unsigned addressSpace = 0); + unsigned getPointerBitwidth(unsigned addressSpace = 0) const; /// Returns the size of the memref descriptor object in bytes. - unsigned getMemRefDescriptorSize(MemRefType type, const DataLayout &layout); + unsigned getMemRefDescriptorSize(MemRefType type, + const DataLayout &layout) const; /// Returns the size of the unranked memref descriptor object in bytes. unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type, - const DataLayout &layout); + const DataLayout &layout) const; /// Return the LLVM address space corresponding to the memory space of the /// memref type `type` or failure if the memory space cannot be converted to /// an integer. - FailureOr getMemRefAddressSpace(BaseMemRefType type); + FailureOr getMemRefAddressSpace(BaseMemRefType type) const; /// Check if a memref type can be converted to a bare pointer. static bool canConvertToBarePtr(BaseMemRefType type); @@ -173,28 +174,28 @@ class LLVMTypeConverter : public TypeConverter { /// one. Additionally, if the function returns more than one value, pack the /// results into an LLVM IR structure type so that the converted function type /// returns at most one result. - Type convertFunctionType(FunctionType type); + Type convertFunctionType(FunctionType type) const; /// Convert the index type. Uses llvmModule data layout to create an integer /// of the pointer bitwidth. - Type convertIndexType(IndexType type); + Type convertIndexType(IndexType type) const; /// Convert an integer type `i*` to `!llvm<"i*">`. - Type convertIntegerType(IntegerType type); + Type convertIntegerType(IntegerType type) const; /// Convert a floating point type: `f16` to `f16`, `f32` to /// `f32` and `f64` to `f64`. `bf16` is not supported /// by LLVM. 8-bit float types are converted to 8-bit integers as this is how /// all LLVM backends that support them currently represent them. - Type convertFloatType(FloatType type); + Type convertFloatType(FloatType type) const; /// Convert complex number type: `complex` to `!llvm<"{ half, half }">`, /// `complex` to `!llvm<"{ float, float }">`, and `complex` to /// `!llvm<"{ double, double }">`. `complex` is not supported. - Type convertComplexType(ComplexType type); + Type convertComplexType(ComplexType type) const; /// Convert a memref type into an LLVM type that captures the relevant data. - Type convertMemRefType(MemRefType type); + Type convertMemRefType(MemRefType type) const; /// Convert a memref type into a list of LLVM IR types that will form the /// memref descriptor. If `unpackAggregates` is true the `sizes` and `strides` @@ -218,7 +219,7 @@ class LLVMTypeConverter : public TypeConverter { /// - `i64`, `i64` (strides). /// These types can be recomposed to a memref descriptor struct. SmallVector getMemRefDescriptorFields(MemRefType type, - bool unpackAggregates); + bool unpackAggregates) const; /// Convert an unranked memref type into a list of non-aggregate LLVM IR types /// that will form the unranked memref descriptor. In particular, this list @@ -229,17 +230,17 @@ class LLVMTypeConverter : public TypeConverter { /// i64 (rank) /// !llvm<"i8*"> (type-erased pointer). /// These types can be recomposed to a unranked memref descriptor struct. - SmallVector getUnrankedMemRefDescriptorFields(); + SmallVector getUnrankedMemRefDescriptorFields() const; /// Convert an unranked memref type to an LLVM type that captures the /// runtime rank and a pointer to the static ranked memref desc - Type convertUnrankedMemRefType(UnrankedMemRefType type); + Type convertUnrankedMemRefType(UnrankedMemRefType type) const; /// Convert a memref type to a bare pointer to the memref element type. - Type convertMemRefToBarePtr(BaseMemRefType type); + Type convertMemRefToBarePtr(BaseMemRefType type) const; /// Convert a 1D vector type into an LLVM vector type. - Type convertVectorType(VectorType type); + Type convertVectorType(VectorType type) const; /// Options for customizing the llvm lowering. LowerToLLVMOptions options; @@ -252,13 +253,13 @@ class LLVMTypeConverter : public TypeConverter { /// argument to a list of non-aggregate types containing descriptor /// information, and an UnrankedmemRef function argument to a list containing /// the rank and a pointer to a descriptor struct. -LogicalResult structFuncArgTypeConverter(LLVMTypeConverter &converter, +LogicalResult structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl &result); /// Callback to convert function argument types. It converts MemRef function /// arguments to bare pointers to the MemRef element type. -LogicalResult barePtrFuncArgTypeConverter(LLVMTypeConverter &converter, +LogicalResult barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl &result); diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h index d115c2d2f58fef..279175b6128fc7 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h @@ -32,7 +32,7 @@ struct NDVectorTypeInfo { // Iterates on the llvm array type until we hit a non-array type (which is // asserted to be an llvm vector type). NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType, - LLVMTypeConverter &converter); + const LLVMTypeConverter &converter); // Express `linearIndex` in terms of coordinates of `basis`. // Returns the empty vector when linearIndex is out of the range [0, P] where @@ -50,14 +50,14 @@ void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder, function_ref)> fun); LogicalResult handleMultidimensionalVectors( - Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter, + Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function createOperand, ConversionPatternRewriter &rewriter); LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef targetAttrs, - LLVMTypeConverter &typeConverter, + const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter); } // namespace detail } // namespace LLVM diff --git a/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h b/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h index 495c4d63986f80..8bf04219c759ae 100644 --- a/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h +++ b/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h @@ -20,7 +20,7 @@ struct AllocationOpLLVMLowering : public ConvertToLLVMPattern { using ConvertToLLVMPattern::getVoidPtrType; explicit AllocationOpLLVMLowering(StringRef opName, - LLVMTypeConverter &converter, + const LLVMTypeConverter &converter, PatternBenefit benefit = 1) : ConvertToLLVMPattern(opName, &converter.getContext(), converter, benefit) {} @@ -107,7 +107,7 @@ struct AllocationOpLLVMLowering : public ConvertToLLVMPattern { /// Lowering for AllocOp and AllocaOp. struct AllocLikeOpLLVMLowering : public AllocationOpLLVMLowering { explicit AllocLikeOpLLVMLowering(StringRef opName, - LLVMTypeConverter &converter, + const LLVMTypeConverter &converter, PatternBenefit benefit = 1) : AllocationOpLLVMLowering(opName, converter, benefit) {} diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h index ba3e8ae89e1606..89ded981d38f9f 100644 --- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h +++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h @@ -83,7 +83,7 @@ class SPIRVTypeConverter : public TypeConverter { const SPIRVConversionOptions &getOptions() const { return options; } /// Checks if the SPIR-V capability inquired is supported. - bool allows(spirv::Capability capability); + bool allows(spirv::Capability capability) const; private: spirv::TargetEnv targetEnv; @@ -169,17 +169,17 @@ Value linearizeIndex(ValueRange indices, ArrayRef strides, // TODO: This method assumes that the `baseType` is a MemRefType with AffineMap // that has static strides. Extend to handle dynamic strides. -Value getElementPtr(SPIRVTypeConverter &typeConverter, MemRefType baseType, - Value basePtr, ValueRange indices, Location loc, - OpBuilder &builder); +Value getElementPtr(const SPIRVTypeConverter &typeConverter, + MemRefType baseType, Value basePtr, ValueRange indices, + Location loc, OpBuilder &builder); // GetElementPtr implementation for Kernel/OpenCL flavored SPIR-V. -Value getOpenCLElementPtr(SPIRVTypeConverter &typeConverter, +Value getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder); // GetElementPtr implementation for Vulkan/Shader flavored SPIR-V. -Value getVulkanElementPtr(SPIRVTypeConverter &typeConverter, +Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder); diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index b4051093d4b0a9..6e11c3ed0a0179 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -476,13 +476,13 @@ class ConversionPattern : public RewritePattern { /// Return the type converter held by this pattern, or nullptr if the pattern /// does not require type conversion. - TypeConverter *getTypeConverter() const { return typeConverter; } + const TypeConverter *getTypeConverter() const { return typeConverter; } template std::enable_if_t::value, - ConverterTy *> + const ConverterTy *> getTypeConverter() const { - return static_cast(typeConverter); + return static_cast(typeConverter); } protected: @@ -492,13 +492,13 @@ class ConversionPattern : public RewritePattern { /// Construct a conversion pattern with the given converter, and forward the /// remaining arguments to RewritePattern. template - ConversionPattern(TypeConverter &typeConverter, Args &&...args) + ConversionPattern(const TypeConverter &typeConverter, Args &&...args) : RewritePattern(std::forward(args)...), typeConverter(&typeConverter) {} protected: /// An optional type converter for use by this pattern. - TypeConverter *typeConverter = nullptr; + const TypeConverter *typeConverter = nullptr; private: using RewritePattern::rewrite; @@ -514,7 +514,7 @@ class OpConversionPattern : public ConversionPattern { OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1) : ConversionPattern(SourceOp::getOperationName(), benefit, context) {} - OpConversionPattern(TypeConverter &typeConverter, MLIRContext *context, + OpConversionPattern(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : ConversionPattern(typeConverter, SourceOp::getOperationName(), benefit, context) {} @@ -567,7 +567,7 @@ class OpInterfaceConversionPattern : public ConversionPattern { OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit = 1) : ConversionPattern(Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(), benefit, context) {} - OpInterfaceConversionPattern(TypeConverter &typeConverter, + OpInterfaceConversionPattern(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : ConversionPattern(typeConverter, Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(), benefit, context) {} @@ -608,17 +608,17 @@ class OpInterfaceConversionPattern : public ConversionPattern { /// ops which use FunctionType to represent their type. void populateFunctionOpInterfaceTypeConversionPattern( StringRef functionLikeOpName, RewritePatternSet &patterns, - TypeConverter &converter); + const TypeConverter &converter); template void populateFunctionOpInterfaceTypeConversionPattern( - RewritePatternSet &patterns, TypeConverter &converter) { + RewritePatternSet &patterns, const TypeConverter &converter) { populateFunctionOpInterfaceTypeConversionPattern(FuncOpT::getOperationName(), patterns, converter); } void populateAnyFunctionOpInterfaceTypeConversionPattern( - RewritePatternSet &patterns, TypeConverter &converter); + RewritePatternSet &patterns, const TypeConverter &converter); //===----------------------------------------------------------------------===// // Conversion PatternRewriter @@ -645,7 +645,7 @@ class ConversionPatternRewriter final : public PatternRewriter, Block * applySignatureConversion(Region *region, TypeConverter::SignatureConversion &conversion, - TypeConverter *converter = nullptr); + const TypeConverter *converter = nullptr); /// Convert the types of block arguments within the given region. This /// replaces each block with a new block containing the updated signature. The @@ -653,7 +653,7 @@ class ConversionPatternRewriter final : public PatternRewriter, /// provided. On success, the new entry block to the region is returned for /// convenience. Otherwise, failure is returned. FailureOr convertRegionTypes( - Region *region, TypeConverter &converter, + Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion = nullptr); /// Convert the types of block arguments within the given region except for @@ -664,7 +664,7 @@ class ConversionPatternRewriter final : public PatternRewriter, /// example, we need to convert only a subset of a BB arguments), such /// behavior can be specified in blockConversions. LogicalResult convertNonEntryRegionTypes( - Region *region, TypeConverter &converter, + Region *region, const TypeConverter &converter, ArrayRef blockConversions); /// Replace all the uses of the block argument `from` with value `to`. @@ -1024,12 +1024,12 @@ class ConversionTarget { class PDLConversionConfig final : public PDLPatternConfigBase { public: - PDLConversionConfig(TypeConverter *converter) : converter(converter) {} + PDLConversionConfig(const TypeConverter *converter) : converter(converter) {} ~PDLConversionConfig() final = default; /// Return the type converter used by this configuration, which may be nullptr /// if no type conversions are expected. - TypeConverter *getTypeConverter() const { return converter; } + const TypeConverter *getTypeConverter() const { return converter; } /// Hooks that are invoked at the beginning and end of a rewrite of a matched /// pattern. @@ -1038,7 +1038,7 @@ class PDLConversionConfig final private: /// An optional type converter to use for the pattern. - TypeConverter *converter; + const TypeConverter *converter; }; /// Register the dialect conversion PDL functions with the given pattern set. diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index ecd4cbb25f2d5c..259b7eeb658ec3 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -42,7 +42,7 @@ namespace { /// Define lowering patterns for raw buffer ops template struct RawBufferOpLowering : public ConvertOpToLLVMPattern { - RawBufferOpLowering(LLVMTypeConverter &converter, Chipset chipset) + RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; @@ -345,7 +345,8 @@ static Value mfmaConcatIfNeeded(ConversionPatternRewriter &rewriter, /// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32 /// vector. static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, - Location loc, TypeConverter *typeConverter, + Location loc, + const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput, SmallVector &operands) { Type inputType = llvmInput.getType(); @@ -384,7 +385,8 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, /// result in the lower 16 bits, set subwordOffset to 1, otherwise result will /// be stored it in the upper part static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, - Location loc, TypeConverter *typeConverter, + Location loc, + const TypeConverter *typeConverter, Value output, int32_t subwordOffset, bool clamp, SmallVector &operands) { Type inputType = output.getType(); @@ -562,7 +564,7 @@ static std::optional wmmaOpToIntrinsic(WMMAOp wmma, namespace { struct MFMAOpLowering : public ConvertOpToLLVMPattern { - MFMAOpLowering(LLVMTypeConverter &converter, Chipset chipset) + MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; @@ -600,7 +602,7 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern { }; struct WMMAOpLowering : public ConvertOpToLLVMPattern { - WMMAOpLowering(LLVMTypeConverter &converter, Chipset chipset) + WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp index 7c8baee1448575..234d06c08da6dc 100644 --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -359,7 +359,7 @@ class AsyncRuntimeTypeConverter : public TypeConverter { /// Creates an LLVM pointer type which may either be a typed pointer or an /// opaque pointer, depending on what options the converter was constructed /// with. - LLVM::LLVMPointerType getPointerType(Type elementType) { + LLVM::LLVMPointerType getPointerType(Type elementType) const { if (llvmOpaquePointers) return LLVM::LLVMPointerType::get(elementType.getContext()); return LLVM::LLVMPointerType::get(elementType); @@ -388,13 +388,14 @@ class AsyncOpConversionPattern : public OpConversionPattern { using Base = OpConversionPattern; public: - AsyncOpConversionPattern(AsyncRuntimeTypeConverter &typeConverter, + AsyncOpConversionPattern(const AsyncRuntimeTypeConverter &typeConverter, MLIRContext *context) : Base(typeConverter, context) {} /// Returns the 'AsyncRuntimeTypeConverter' of the pattern. - AsyncRuntimeTypeConverter *getTypeConverter() const { - return static_cast(Base::getTypeConverter()); + const AsyncRuntimeTypeConverter *getTypeConverter() const { + return static_cast( + Base::getTypeConverter()); } }; @@ -653,7 +654,7 @@ class RuntimeCreateOpLowering : public ConvertOpToLLVMPattern { LogicalResult matchAndRewrite(RuntimeCreateOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - TypeConverter *converter = getTypeConverter(); + const TypeConverter *converter = getTypeConverter(); Type resultType = op->getResultTypes()[0]; // Tokens creation maps to a simple function call. @@ -706,7 +707,7 @@ class RuntimeCreateGroupOpLowering LogicalResult matchAndRewrite(RuntimeCreateGroupOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - TypeConverter *converter = getTypeConverter(); + const TypeConverter *converter = getTypeConverter(); Type resultType = op.getResult().getType(); rewriter.replaceOpWithNewOp( @@ -1040,8 +1041,8 @@ namespace { template class RefCountingOpLowering : public OpConversionPattern { public: - explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx, - StringRef apiFunctionName) + explicit RefCountingOpLowering(const TypeConverter &converter, + MLIRContext *ctx, StringRef apiFunctionName) : OpConversionPattern(converter, ctx), apiFunctionName(apiFunctionName) {} @@ -1065,14 +1066,16 @@ class RefCountingOpLowering : public OpConversionPattern { class RuntimeAddRefOpLowering : public RefCountingOpLowering { public: - explicit RuntimeAddRefOpLowering(TypeConverter &converter, MLIRContext *ctx) + explicit RuntimeAddRefOpLowering(const TypeConverter &converter, + MLIRContext *ctx) : RefCountingOpLowering(converter, ctx, kAddRef) {} }; class RuntimeDropRefOpLowering : public RefCountingOpLowering { public: - explicit RuntimeDropRefOpLowering(TypeConverter &converter, MLIRContext *ctx) + explicit RuntimeDropRefOpLowering(const TypeConverter &converter, + MLIRContext *ctx) : RefCountingOpLowering(converter, ctx, kDropRef) {} }; } // namespace diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp index d99968d78d248c..a4f146bbe475cc 100644 --- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp +++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp @@ -46,7 +46,8 @@ static std::string generateGlobalMsgSymbolName(ModuleOp moduleOp) { /// Generate IR that prints the given string to stderr. static void createPrintMsg(OpBuilder &builder, Location loc, ModuleOp moduleOp, - StringRef msg, LLVMTypeConverter &typeConverter) { + StringRef msg, + const LLVMTypeConverter &typeConverter) { auto ip = builder.saveInsertionPoint(); builder.setInsertionPointToStart(moduleOp.getBody()); MLIRContext *ctx = builder.getContext(); diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp index 7ee0ea91827f22..1db463c0ab7163 100644 --- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp +++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp @@ -63,7 +63,7 @@ static constexpr StringRef barePtrAttrName = "llvm.bareptr"; /// Return `true` if the `op` should use bare pointer calling convention. static bool shouldUseBarePtrCallConv(Operation *op, - LLVMTypeConverter *typeConverter) { + const LLVMTypeConverter *typeConverter) { return (op && op->hasAttr(barePtrAttrName)) || typeConverter->getOptions().useBarePtrCallConv; } @@ -118,7 +118,7 @@ static void prependEmptyArgAttr(OpBuilder &builder, /// components and forwards them to `newFuncOp` and forwards the results to /// the extra arguments. static void wrapForExternalCallers(OpBuilder &rewriter, Location loc, - LLVMTypeConverter &typeConverter, + const LLVMTypeConverter &typeConverter, func::FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) { auto type = funcOp.getFunctionType(); @@ -182,7 +182,7 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc, /// compatible with functions defined in C using pointers to C structs /// corresponding to a memref descriptor. static void wrapExternalFunction(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, + const LLVMTypeConverter &typeConverter, func::FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) { OpBuilder::InsertionGuard guard(builder); @@ -281,7 +281,7 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc, /// the bare pointer calling convention lowering of `memref` types. static void modifyFuncOpToUseBarePtrCallingConv( ConversionPatternRewriter &rewriter, Location loc, - LLVMTypeConverter &typeConverter, LLVM::LLVMFuncOp funcOp, + const LLVMTypeConverter &typeConverter, LLVM::LLVMFuncOp funcOp, TypeRange oldArgTypes) { if (funcOp.getBody().empty()) return; @@ -469,7 +469,7 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern { /// MemRef descriptors (LLVM struct data types) containing all the MemRef type /// information. struct FuncOpConversion : public FuncOpConversionBase { - FuncOpConversion(LLVMTypeConverter &converter) + FuncOpConversion(const LLVMTypeConverter &converter) : FuncOpConversionBase(converter) {} LogicalResult diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp index f7caf025fb79bd..2a26587be0b412 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -464,7 +464,7 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite( /// Unrolls op if it's operating on vectors. LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, - LLVMTypeConverter &converter) { + const LLVMTypeConverter &converter) { TypeRange operandTypes(operands); if (llvm::none_of(operandTypes, [](Type type) { return isa(type); })) { diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h index d61f22c9fc37df..bd90286494d803 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h @@ -15,8 +15,9 @@ namespace mlir { struct GPUFuncOpLowering : ConvertOpToLLVMPattern { - GPUFuncOpLowering(LLVMTypeConverter &converter, unsigned allocaAddrSpace, - unsigned workgroupAddrSpace, StringAttr kernelAttributeName) + GPUFuncOpLowering(const LLVMTypeConverter &converter, + unsigned allocaAddrSpace, unsigned workgroupAddrSpace, + StringAttr kernelAttributeName) : ConvertOpToLLVMPattern(converter), allocaAddrSpace(allocaAddrSpace), workgroupAddrSpace(workgroupAddrSpace), @@ -57,7 +58,7 @@ struct GPUPrintfOpToHIPLowering : public ConvertOpToLLVMPattern { /// will lower printf calls to appropriate device-side code struct GPUPrintfOpToLLVMCallLowering : public ConvertOpToLLVMPattern { - GPUPrintfOpToLLVMCallLowering(LLVMTypeConverter &converter, + GPUPrintfOpToLLVMCallLowering(const LLVMTypeConverter &converter, int addressSpace = 0) : ConvertOpToLLVMPattern(converter), addressSpace(addressSpace) {} @@ -95,7 +96,7 @@ namespace impl { /// Unrolls op if it's operating on vectors. LogicalResult scalarizeVectorOp(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, - LLVMTypeConverter &converter); + const LLVMTypeConverter &converter); } // namespace impl /// Rewriting that unrolls SourceOp to scalars if it's operating on vectors. diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp index 666dc8e27a9f7d..e0e9a7169bc6b9 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -61,7 +61,8 @@ class GpuToLLVMConversionPass template class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern { public: - explicit ConvertOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter) + explicit ConvertOpToGpuRuntimeCallPattern( + const LLVMTypeConverter &typeConverter) : ConvertOpToLLVMPattern(typeConverter) {} protected: @@ -341,7 +342,8 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern { class ConvertHostRegisterOpToGpuRuntimeCallPattern : public ConvertOpToGpuRuntimeCallPattern { public: - ConvertHostRegisterOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter) + ConvertHostRegisterOpToGpuRuntimeCallPattern( + const LLVMTypeConverter &typeConverter) : ConvertOpToGpuRuntimeCallPattern(typeConverter) {} private: @@ -354,7 +356,7 @@ class ConvertHostUnregisterOpToGpuRuntimeCallPattern : public ConvertOpToGpuRuntimeCallPattern { public: ConvertHostUnregisterOpToGpuRuntimeCallPattern( - LLVMTypeConverter &typeConverter) + const LLVMTypeConverter &typeConverter) : ConvertOpToGpuRuntimeCallPattern(typeConverter) { } @@ -369,7 +371,7 @@ class ConvertHostUnregisterOpToGpuRuntimeCallPattern class ConvertAllocOpToGpuRuntimeCallPattern : public ConvertOpToGpuRuntimeCallPattern { public: - ConvertAllocOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter) + ConvertAllocOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter) : ConvertOpToGpuRuntimeCallPattern(typeConverter) {} private: @@ -383,7 +385,8 @@ class ConvertAllocOpToGpuRuntimeCallPattern class ConvertDeallocOpToGpuRuntimeCallPattern : public ConvertOpToGpuRuntimeCallPattern { public: - ConvertDeallocOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter) + ConvertDeallocOpToGpuRuntimeCallPattern( + const LLVMTypeConverter &typeConverter) : ConvertOpToGpuRuntimeCallPattern(typeConverter) {} private: @@ -395,7 +398,8 @@ class ConvertDeallocOpToGpuRuntimeCallPattern class ConvertAsyncYieldToGpuRuntimeCallPattern : public ConvertOpToGpuRuntimeCallPattern { public: - ConvertAsyncYieldToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter) + ConvertAsyncYieldToGpuRuntimeCallPattern( + const LLVMTypeConverter &typeConverter) : ConvertOpToGpuRuntimeCallPattern(typeConverter) {} private: @@ -409,7 +413,7 @@ class ConvertAsyncYieldToGpuRuntimeCallPattern class ConvertWaitOpToGpuRuntimeCallPattern : public ConvertOpToGpuRuntimeCallPattern { public: - ConvertWaitOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter) + ConvertWaitOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter) : ConvertOpToGpuRuntimeCallPattern(typeConverter) {} private: @@ -423,7 +427,8 @@ class ConvertWaitOpToGpuRuntimeCallPattern class ConvertWaitAsyncOpToGpuRuntimeCallPattern : public ConvertOpToGpuRuntimeCallPattern { public: - ConvertWaitAsyncOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter) + ConvertWaitAsyncOpToGpuRuntimeCallPattern( + const LLVMTypeConverter &typeConverter) : ConvertOpToGpuRuntimeCallPattern(typeConverter) {} private: @@ -448,10 +453,9 @@ class ConvertWaitAsyncOpToGpuRuntimeCallPattern class ConvertLaunchFuncOpToGpuRuntimeCallPattern : public ConvertOpToGpuRuntimeCallPattern { public: - ConvertLaunchFuncOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter, - StringRef gpuBinaryAnnotation, - bool kernelBarePtrCallConv, - SymbolTable *cachedModuleTable) + ConvertLaunchFuncOpToGpuRuntimeCallPattern( + const LLVMTypeConverter &typeConverter, StringRef gpuBinaryAnnotation, + bool kernelBarePtrCallConv, SymbolTable *cachedModuleTable) : ConvertOpToGpuRuntimeCallPattern(typeConverter), gpuBinaryAnnotation(gpuBinaryAnnotation), kernelBarePtrCallConv(kernelBarePtrCallConv), @@ -489,7 +493,7 @@ class EraseGpuModuleOpPattern : public OpRewritePattern { class ConvertMemcpyOpToGpuRuntimeCallPattern : public ConvertOpToGpuRuntimeCallPattern { public: - ConvertMemcpyOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter) + ConvertMemcpyOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter) : ConvertOpToGpuRuntimeCallPattern(typeConverter) {} private: @@ -503,7 +507,7 @@ class ConvertMemcpyOpToGpuRuntimeCallPattern class ConvertMemsetOpToGpuRuntimeCallPattern : public ConvertOpToGpuRuntimeCallPattern { public: - ConvertMemsetOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter) + ConvertMemsetOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter) : ConvertOpToGpuRuntimeCallPattern(typeConverter) {} private: @@ -518,7 +522,7 @@ class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern : public ConvertOpToGpuRuntimeCallPattern { public: ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern( - LLVMTypeConverter &typeConverter) + const LLVMTypeConverter &typeConverter) : ConvertOpToGpuRuntimeCallPattern( typeConverter) {} @@ -534,7 +538,7 @@ class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern : public ConvertOpToGpuRuntimeCallPattern { \ public: \ Convert##op_name##ToGpuRuntimeCallPattern( \ - LLVMTypeConverter &typeConverter) \ + const LLVMTypeConverter &typeConverter) \ : ConvertOpToGpuRuntimeCallPattern(typeConverter) {} \ \ private: \ @@ -980,15 +984,15 @@ Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray( SmallVector arguments; if (kernelBarePtrCallConv) { // Hack the bare pointer value on just for the argument promotion - LLVMTypeConverter *converter = getTypeConverter(); + const LLVMTypeConverter *converter = getTypeConverter(); LowerToLLVMOptions options = converter->getOptions(); LowerToLLVMOptions overrideToMatchKernelOpts = options; overrideToMatchKernelOpts.useBarePtrCallConv = true; - converter->dangerousSetOptions(overrideToMatchKernelOpts); - arguments = converter->promoteOperands( + LLVMTypeConverter newConverter = *converter; + newConverter.dangerousSetOptions(overrideToMatchKernelOpts); + arguments = newConverter.promoteOperands( loc, launchOp.getOperands().take_back(numKernelOperands), adaptor.getOperands().take_back(numKernelOperands), builder); - converter->dangerousSetOptions(options); } else { arguments = getTypeConverter()->promoteOperands( loc, launchOp.getOperands().take_back(numKernelOperands), @@ -1111,15 +1115,15 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite( SmallVector arguments; if (kernelBarePtrCallConv) { // Hack the bare pointer value on just for the argument promotion - LLVMTypeConverter *converter = getTypeConverter(); + const LLVMTypeConverter *converter = getTypeConverter(); LowerToLLVMOptions options = converter->getOptions(); LowerToLLVMOptions overrideToMatchKernelOpts = options; overrideToMatchKernelOpts.useBarePtrCallConv = true; - converter->dangerousSetOptions(overrideToMatchKernelOpts); + LLVMTypeConverter newConverter = *converter; + newConverter.dangerousSetOptions(overrideToMatchKernelOpts); arguments = - converter->promoteOperands(loc, launchOp.getKernelOperands(), - adaptor.getKernelOperands(), rewriter); - converter->dangerousSetOptions(options); + newConverter.promoteOperands(loc, launchOp.getKernelOperands(), + adaptor.getKernelOperands(), rewriter); } else { arguments = getTypeConverter()->promoteOperands( loc, launchOp.getKernelOperands(), adaptor.getKernelOperands(), @@ -1200,7 +1204,7 @@ static Value bitAndAddrspaceCast(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMPointerType destinationType, Value sourcePtr, - LLVMTypeConverter &typeConverter) { + const LLVMTypeConverter &typeConverter) { auto sourceTy = cast(sourcePtr.getType()); if (destinationType.getAddressSpace() != sourceTy.getAddressSpace()) sourcePtr = rewriter.create( diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp index feea1e34f1b43b..693cc3f6236b57 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -222,7 +222,7 @@ LogicalResult WorkGroupSizeConversion::matchAndRewrite( // Legalizes a GPU function as an entry SPIR-V function. static spirv::FuncOp -lowerAsEntryFunction(gpu::GPUFuncOp funcOp, TypeConverter &typeConverter, +lowerAsEntryFunction(gpu::GPUFuncOp funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter, spirv::EntryPointABIAttr entryPointInfo, ArrayRef argABIInfo) { diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp index 57e21530b9da76..3851fb728b6654 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp @@ -88,7 +88,7 @@ struct WmmaLoadOpToSPIRVLowering auto memrefType = cast(subgroupMmaLoadMatrixOp.getSrcMemref().getType()); Value bufferPtr = spirv::getElementPtr( - *getTypeConverter(), memrefType, + *getTypeConverter(), memrefType, adaptor.getSrcMemref(), adaptor.getIndices(), loc, rewriter); auto coopType = convertMMAToSPIRVType(retType); int64_t stride = subgroupMmaLoadMatrixOp.getLeadDimension().getSExtValue(); @@ -119,7 +119,7 @@ struct WmmaStoreOpToSPIRVLowering auto memrefType = cast(subgroupMmaStoreMatrixOp.getDstMemref().getType()); Value bufferPtr = spirv::getElementPtr( - *getTypeConverter(), memrefType, + *getTypeConverter(), memrefType, adaptor.getDstMemref(), adaptor.getIndices(), loc, rewriter); int64_t stride = subgroupMmaStoreMatrixOp.getLeadDimension().getSExtValue(); auto i32Type = rewriter.getI32Type(); diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp index 2c9580e421340a..0a3c9a57eec95d 100644 --- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp +++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp @@ -41,13 +41,13 @@ MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc, /// type. MemRefDescriptor MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, + const LLVMTypeConverter &typeConverter, MemRefType type, Value memory) { return fromStaticShape(builder, loc, typeConverter, type, memory, memory); } MemRefDescriptor MemRefDescriptor::fromStaticShape( - OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, + OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, MemRefType type, Value memory, Value alignedMemory) { assert(type.hasStaticShape() && "unexpected dynamic shape"); @@ -198,7 +198,7 @@ LLVM::LLVMPointerType MemRefDescriptor::getElementPtrType() { } Value MemRefDescriptor::bufferPtr(OpBuilder &builder, Location loc, - LLVMTypeConverter &converter, + const LLVMTypeConverter &converter, MemRefType type) { // When we convert to LLVM, the input memref must have been normalized // beforehand. Hence, this call is guaranteed to work. @@ -230,8 +230,8 @@ Value MemRefDescriptor::bufferPtr(OpBuilder &builder, Location loc, /// - shapes; /// where is the MemRef rank as provided in `type`. Value MemRefDescriptor::pack(OpBuilder &builder, Location loc, - LLVMTypeConverter &converter, MemRefType type, - ValueRange values) { + const LLVMTypeConverter &converter, + MemRefType type, ValueRange values) { Type llvmType = converter.convertType(type); auto d = MemRefDescriptor::undef(builder, loc, llvmType); @@ -340,7 +340,7 @@ void UnrankedMemRefDescriptor::setMemRefDescPtr(OpBuilder &builder, /// - rank of the memref; /// - pointer to the memref descriptor. Value UnrankedMemRefDescriptor::pack(OpBuilder &builder, Location loc, - LLVMTypeConverter &converter, + const LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values) { Type llvmType = converter.convertType(type); @@ -363,7 +363,7 @@ void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc, } void UnrankedMemRefDescriptor::computeSizes( - OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, + OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, ArrayRef values, ArrayRef addressSpaces, SmallVectorImpl &sizes) { if (values.empty()) @@ -453,10 +453,9 @@ castToElemPtrPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, return {elementPtrPtr, elemPtrPtrType}; } -Value UnrankedMemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, - Value memRefDescPtr, - LLVM::LLVMPointerType elemPtrType) { +Value UnrankedMemRefDescriptor::alignedPtr( + OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, + Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType) { auto [elementPtrPtr, elemPtrPtrType] = castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType); @@ -466,11 +465,9 @@ Value UnrankedMemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc, return builder.create(loc, elemPtrType, alignedGep); } -void UnrankedMemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, - Value memRefDescPtr, - LLVM::LLVMPointerType elemPtrType, - Value alignedPtr) { +void UnrankedMemRefDescriptor::setAlignedPtr( + OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, + Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value alignedPtr) { auto [elementPtrPtr, elemPtrPtrType] = castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType); @@ -481,7 +478,7 @@ void UnrankedMemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc, } Value UnrankedMemRefDescriptor::offsetBasePtr( - OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, + OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType) { auto [elementPtrPtr, elemPtrPtrType] = castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType); @@ -499,7 +496,7 @@ Value UnrankedMemRefDescriptor::offsetBasePtr( } Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, + const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType) { Value offsetPtr = @@ -509,7 +506,7 @@ Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc, } void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, + const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value offset) { @@ -518,10 +515,9 @@ void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc, builder.create(loc, offset, offsetPtr); } -Value UnrankedMemRefDescriptor::sizeBasePtr(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, - Value memRefDescPtr, - LLVM::LLVMPointerType elemPtrType) { +Value UnrankedMemRefDescriptor::sizeBasePtr( + OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, + Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType) { Type indexTy = typeConverter.getIndexType(); Type structTy = LLVM::LLVMStructType::getLiteral( indexTy.getContext(), {elemPtrType, elemPtrType, indexTy, indexTy}); @@ -542,7 +538,7 @@ Value UnrankedMemRefDescriptor::sizeBasePtr(OpBuilder &builder, Location loc, } Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, + const LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value index) { Type indexTy = typeConverter.getIndexType(); @@ -554,7 +550,7 @@ Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc, } void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, + const LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value index, Value size) { Type indexTy = typeConverter.getIndexType(); @@ -565,9 +561,9 @@ void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc, builder.create(loc, size, sizeStoreGep); } -Value UnrankedMemRefDescriptor::strideBasePtr(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, - Value sizeBasePtr, Value rank) { +Value UnrankedMemRefDescriptor::strideBasePtr( + OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, + Value sizeBasePtr, Value rank) { Type indexTy = typeConverter.getIndexType(); Type indexPtrTy = typeConverter.getPointerType(indexTy); @@ -576,7 +572,7 @@ Value UnrankedMemRefDescriptor::strideBasePtr(OpBuilder &builder, Location loc, } Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, + const LLVMTypeConverter &typeConverter, Value strideBasePtr, Value index, Value stride) { Type indexTy = typeConverter.getIndexType(); @@ -588,7 +584,7 @@ Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc, } void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, + const LLVMTypeConverter &typeConverter, Value strideBasePtr, Value index, Value stride) { Type indexTy = typeConverter.getIndexType(); diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp index 1699172eb9dab3..e5519df9b0185f 100644 --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -19,14 +19,13 @@ using namespace mlir; // ConvertToLLVMPattern //===----------------------------------------------------------------------===// -ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName, - MLIRContext *context, - LLVMTypeConverter &typeConverter, - PatternBenefit benefit) +ConvertToLLVMPattern::ConvertToLLVMPattern( + StringRef rootOpName, MLIRContext *context, + const LLVMTypeConverter &typeConverter, PatternBenefit benefit) : ConversionPattern(typeConverter, rootOpName, benefit, context) {} -LLVMTypeConverter *ConvertToLLVMPattern::getTypeConverter() const { - return static_cast( +const LLVMTypeConverter *ConvertToLLVMPattern::getTypeConverter() const { + return static_cast( ConversionPattern::getTypeConverter()); } @@ -337,10 +336,12 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( /// Replaces the given operation "op" with a new operation of type "targetOp" /// and given operands. -LogicalResult LLVM::detail::oneToOneRewrite( - Operation *op, StringRef targetOp, ValueRange operands, - ArrayRef targetAttrs, LLVMTypeConverter &typeConverter, - ConversionPatternRewriter &rewriter) { +LogicalResult +LLVM::detail::oneToOneRewrite(Operation *op, StringRef targetOp, + ValueRange operands, + ArrayRef targetAttrs, + const LLVMTypeConverter &typeConverter, + ConversionPatternRewriter &rewriter) { unsigned numResults = op->getNumResults(); SmallVector resultTypes; diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp index 9e03e2ffbacf83..b0842b9972c76d 100644 --- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp +++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp @@ -166,34 +166,35 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, } /// Returns the MLIR context. -MLIRContext &LLVMTypeConverter::getContext() { +MLIRContext &LLVMTypeConverter::getContext() const { return *getDialect()->getContext(); } -Type LLVMTypeConverter::getIndexType() { +Type LLVMTypeConverter::getIndexType() const { return IntegerType::get(&getContext(), getIndexTypeBitwidth()); } LLVM::LLVMPointerType -LLVMTypeConverter::getPointerType(Type elementType, unsigned int addressSpace) { +LLVMTypeConverter::getPointerType(Type elementType, + unsigned int addressSpace) const { if (useOpaquePointers()) return LLVM::LLVMPointerType::get(&getContext(), addressSpace); return LLVM::LLVMPointerType::get(elementType, addressSpace); } -unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) { +unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) const { return options.dataLayout.getPointerSizeInBits(addressSpace); } -Type LLVMTypeConverter::convertIndexType(IndexType type) { +Type LLVMTypeConverter::convertIndexType(IndexType type) const { return getIndexType(); } -Type LLVMTypeConverter::convertIntegerType(IntegerType type) { +Type LLVMTypeConverter::convertIntegerType(IntegerType type) const { return IntegerType::get(&getContext(), type.getWidth()); } -Type LLVMTypeConverter::convertFloatType(FloatType type) { +Type LLVMTypeConverter::convertFloatType(FloatType type) const { if (type.isFloat8E5M2() || type.isFloat8E4M3FN() || type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ() || type.isFloat8E4M3B11FNUZ()) return IntegerType::get(&getContext(), type.getWidth()); @@ -204,7 +205,7 @@ Type LLVMTypeConverter::convertFloatType(FloatType type) { // struct with entries for the // 1. real part and for the // 2. imaginary part. -Type LLVMTypeConverter::convertComplexType(ComplexType type) { +Type LLVMTypeConverter::convertComplexType(ComplexType type) const { auto elementType = convertType(type.getElementType()); return LLVM::LLVMStructType::getLiteral(&getContext(), {elementType, elementType}); @@ -212,7 +213,7 @@ Type LLVMTypeConverter::convertComplexType(ComplexType type) { // Except for signatures, MLIR function types are converted into LLVM // pointer-to-function types. -Type LLVMTypeConverter::convertFunctionType(FunctionType type) { +Type LLVMTypeConverter::convertFunctionType(FunctionType type) const { SignatureConversion conversion(type.getNumInputs()); Type converted = convertFunctionSignature( type, /*isVariadic=*/false, options.useBarePtrCallConv, conversion); @@ -227,7 +228,7 @@ Type LLVMTypeConverter::convertFunctionType(FunctionType type) { // they are into an LLVM StructType in their order of appearance. Type LLVMTypeConverter::convertFunctionSignature( FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv, - LLVMTypeConverter::SignatureConversion &result) { + LLVMTypeConverter::SignatureConversion &result) const { // Select the argument converter depending on the calling convention. useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv; auto funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter @@ -256,7 +257,7 @@ Type LLVMTypeConverter::convertFunctionSignature( /// Converts the function type to a C-compatible format, in particular using /// pointers to memref descriptors for arguments. std::pair -LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) { +LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) const { SmallVector inputs; Type resultType = type.getNumResults() == 0 @@ -315,7 +316,7 @@ LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) { /// }; SmallVector LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type, - bool unpackAggregates) { + bool unpackAggregates) const { if (!isStrided(type)) { emitError( UnknownLoc::get(type.getContext()), @@ -353,8 +354,9 @@ LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type, return results; } -unsigned LLVMTypeConverter::getMemRefDescriptorSize(MemRefType type, - const DataLayout &layout) { +unsigned +LLVMTypeConverter::getMemRefDescriptorSize(MemRefType type, + const DataLayout &layout) const { // Compute the descriptor size given that of its components indicated above. unsigned space = *getMemRefAddressSpace(type); return 2 * llvm::divideCeil(getPointerBitwidth(space), 8) + @@ -363,7 +365,7 @@ unsigned LLVMTypeConverter::getMemRefDescriptorSize(MemRefType type, /// Converts MemRefType to LLVMType. A MemRefType is converted to a struct that /// packs the descriptor fields as defined by `getMemRefDescriptorFields`. -Type LLVMTypeConverter::convertMemRefType(MemRefType type) { +Type LLVMTypeConverter::convertMemRefType(MemRefType type) const { // When converting a MemRefType to a struct with descriptor fields, do not // unpack the `sizes` and `strides` arrays. SmallVector types = @@ -380,20 +382,21 @@ Type LLVMTypeConverter::convertMemRefType(MemRefType type) { /// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be /// stack allocated (alloca) copy of a MemRef descriptor that got casted to /// be unranked. -SmallVector LLVMTypeConverter::getUnrankedMemRefDescriptorFields() { +SmallVector +LLVMTypeConverter::getUnrankedMemRefDescriptorFields() const { return {getIndexType(), getPointerType(IntegerType::get(&getContext(), 8))}; } -unsigned -LLVMTypeConverter::getUnrankedMemRefDescriptorSize(UnrankedMemRefType type, - const DataLayout &layout) { +unsigned LLVMTypeConverter::getUnrankedMemRefDescriptorSize( + UnrankedMemRefType type, const DataLayout &layout) const { // Compute the descriptor size given that of its components indicated above. unsigned space = *getMemRefAddressSpace(type); return layout.getTypeSize(getIndexType()) + llvm::divideCeil(getPointerBitwidth(space), 8); } -Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) { +Type LLVMTypeConverter::convertUnrankedMemRefType( + UnrankedMemRefType type) const { if (!convertType(type.getElementType())) return {}; return LLVM::LLVMStructType::getLiteral(&getContext(), @@ -401,7 +404,7 @@ Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) { } FailureOr -LLVMTypeConverter::getMemRefAddressSpace(BaseMemRefType type) { +LLVMTypeConverter::getMemRefAddressSpace(BaseMemRefType type) const { if (!type.getMemorySpace()) // Default memory space -> 0. return 0; std::optional converted = @@ -440,7 +443,7 @@ bool LLVMTypeConverter::canConvertToBarePtr(BaseMemRefType type) { } /// Convert a memref type to a bare pointer to the memref element type. -Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) { +Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) const { if (!canConvertToBarePtr(type)) return {}; Type elementType = convertType(type.getElementType()); @@ -460,7 +463,7 @@ Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) { /// As LLVM does not support arrays of scalable vectors, it is assumed that /// scalable vectors are always 1-D. This condition could be relaxed once the /// missing functionality is added in LLVM -Type LLVMTypeConverter::convertVectorType(VectorType type) { +Type LLVMTypeConverter::convertVectorType(VectorType type) const { auto elementType = convertType(type.getElementType()); if (!elementType) return {}; @@ -484,8 +487,8 @@ Type LLVMTypeConverter::convertVectorType(VectorType type) { /// UnrankedMemRefType, are converted following the specific rules for the /// calling convention. Calling convention independent types are converted /// following the default LLVM type conversions. -Type LLVMTypeConverter::convertCallingConventionType(Type type, - bool useBarePtrCallConv) { +Type LLVMTypeConverter::convertCallingConventionType( + Type type, bool useBarePtrCallConv) const { if (useBarePtrCallConv) if (auto memrefTy = dyn_cast(type)) return convertMemRefToBarePtr(memrefTy); @@ -498,7 +501,7 @@ Type LLVMTypeConverter::convertCallingConventionType(Type type, /// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type). void LLVMTypeConverter::promoteBarePtrsToDescriptors( ConversionPatternRewriter &rewriter, Location loc, ArrayRef stdTypes, - SmallVectorImpl &values) { + SmallVectorImpl &values) const { assert(stdTypes.size() == values.size() && "The number of types and values doesn't match"); for (unsigned i = 0, end = values.size(); i < end; ++i) @@ -511,7 +514,7 @@ void LLVMTypeConverter::promoteBarePtrsToDescriptors( /// LLVM-compatible type. In particular, if more than one value is /// produced, create a literal structure with elements that correspond to each /// of the types converted with `convertType`. -Type LLVMTypeConverter::packOperationResults(TypeRange types) { +Type LLVMTypeConverter::packOperationResults(TypeRange types) const { assert(!types.empty() && "expected non-empty list of type"); if (types.size() == 1) return convertType(types[0]); @@ -533,7 +536,7 @@ Type LLVMTypeConverter::packOperationResults(TypeRange types) { /// create an LLVM dialect structure type with elements that correspond to each /// of the types converted with `convertCallingConventionType`. Type LLVMTypeConverter::packFunctionResults(TypeRange types, - bool useBarePtrCallConv) { + bool useBarePtrCallConv) const { assert(!types.empty() && "expected non-empty list of type"); useBarePtrCallConv |= options.useBarePtrCallConv; @@ -553,7 +556,7 @@ Type LLVMTypeConverter::packFunctionResults(TypeRange types, } Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand, - OpBuilder &builder) { + OpBuilder &builder) const { // Alloca with proper alignment. We do not expect optimizations of this // alloca op and so we omit allocating at the entry block. auto ptrType = getPointerType(operand.getType()); @@ -569,7 +572,7 @@ Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand, SmallVector LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands, ValueRange operands, OpBuilder &builder, - bool useBarePtrCallConv) { + bool useBarePtrCallConv) const { SmallVector promotedOperands; promotedOperands.reserve(operands.size()); useBarePtrCallConv |= options.useBarePtrCallConv; @@ -608,9 +611,9 @@ LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands, /// argument to a list of non-aggregate types containing descriptor /// information, and an UnrankedmemRef function argument to a list containing /// the rank and a pointer to a descriptor struct. -LogicalResult mlir::structFuncArgTypeConverter(LLVMTypeConverter &converter, - Type type, - SmallVectorImpl &result) { +LogicalResult +mlir::structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, + SmallVectorImpl &result) { if (auto memref = dyn_cast(type)) { // In signatures, Memref descriptors are expanded into lists of // non-aggregate values. @@ -637,9 +640,9 @@ LogicalResult mlir::structFuncArgTypeConverter(LLVMTypeConverter &converter, /// Callback to convert function argument types. It converts MemRef function /// arguments to bare pointers to the MemRef element type. -LogicalResult mlir::barePtrFuncArgTypeConverter(LLVMTypeConverter &converter, - Type type, - SmallVectorImpl &result) { +LogicalResult +mlir::barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, + SmallVectorImpl &result) { auto llvmTy = converter.convertCallingConventionType( type, /*useBarePointerCallConv=*/true); if (!llvmTy) diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp index 732f6c578c8b57..544bcc71aca1b5 100644 --- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp @@ -17,7 +17,7 @@ using namespace mlir; // asserted to be an llvm vector type). LLVM::detail::NDVectorTypeInfo LLVM::detail::extractNDVectorTypeInfo(VectorType vectorType, - LLVMTypeConverter &converter) { + const LLVMTypeConverter &converter) { assert(vectorType.getRank() > 1 && "expected >1D vector type"); NDVectorTypeInfo info; info.llvmNDVectorTy = converter.convertType(vectorType); @@ -78,7 +78,7 @@ void LLVM::detail::nDVectorIterate(const LLVM::detail::NDVectorTypeInfo &info, } LogicalResult LLVM::detail::handleMultidimensionalVectors( - Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter, + Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function createOperand, ConversionPatternRewriter &rewriter) { auto resultNDVectorType = cast(op->getResult(0).getType()); @@ -103,10 +103,12 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors( return success(); } -LogicalResult LLVM::detail::vectorOneToOneRewrite( - Operation *op, StringRef targetOp, ValueRange operands, - ArrayRef targetAttrs, LLVMTypeConverter &typeConverter, - ConversionPatternRewriter &rewriter) { +LogicalResult +LLVM::detail::vectorOneToOneRewrite(Operation *op, StringRef targetOp, + ValueRange operands, + ArrayRef targetAttrs, + const LLVMTypeConverter &typeConverter, + ConversionPatternRewriter &rewriter) { assert(!operands.empty()); // Cannot convert ops if their operands are not of LLVM type. diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp index 715d00f2e215ac..a2a426e3c29317 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp @@ -18,7 +18,7 @@ namespace { // with SymbolTable trait instead of ModuleOp and make similar change here. This // allows call sites to use getParentWithTrait instead // of getParentOfType to pass down the operation. -LLVM::LLVMFuncOp getNotalignedAllocFn(LLVMTypeConverter *typeConverter, +LLVM::LLVMFuncOp getNotalignedAllocFn(const LLVMTypeConverter *typeConverter, ModuleOp module, Type indexType) { bool useGenericFn = typeConverter->getOptions().useGenericFunctions; @@ -30,7 +30,7 @@ LLVM::LLVMFuncOp getNotalignedAllocFn(LLVMTypeConverter *typeConverter, typeConverter->useOpaquePointers()); } -LLVM::LLVMFuncOp getAlignedAllocFn(LLVMTypeConverter *typeConverter, +LLVM::LLVMFuncOp getAlignedAllocFn(const LLVMTypeConverter *typeConverter, ModuleOp module, Type indexType) { bool useGenericFn = typeConverter->getOptions().useGenericFunctions; @@ -57,7 +57,7 @@ Value AllocationOpLLVMLowering::createAligned( static Value castAllocFuncResult(ConversionPatternRewriter &rewriter, Location loc, Value allocatedPtr, MemRefType memRefType, Type elementPtrType, - LLVMTypeConverter &typeConverter) { + const LLVMTypeConverter &typeConverter) { auto allocatedPtrTy = cast(allocatedPtr.getType()); unsigned memrefAddrSpace = *typeConverter.getMemRefAddressSpace(memRefType); if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace) diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index 4b27dcb6cda281..8843ab78eed782 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -41,7 +41,8 @@ bool isStaticStrideOrOffset(int64_t strideOrOffset) { return !ShapedType::isDynamic(strideOrOffset); } -LLVM::LLVMFuncOp getFreeFn(LLVMTypeConverter *typeConverter, ModuleOp module) { +LLVM::LLVMFuncOp getFreeFn(const LLVMTypeConverter *typeConverter, + ModuleOp module) { bool useGenericFn = typeConverter->getOptions().useGenericFunctions; if (useGenericFn) @@ -52,7 +53,7 @@ LLVM::LLVMFuncOp getFreeFn(LLVMTypeConverter *typeConverter, ModuleOp module) { } struct AllocOpLowering : public AllocLikeOpLLVMLowering { - AllocOpLowering(LLVMTypeConverter &converter) + AllocOpLowering(const LLVMTypeConverter &converter) : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(), converter) {} std::tuple allocateBuffer(ConversionPatternRewriter &rewriter, @@ -65,7 +66,7 @@ struct AllocOpLowering : public AllocLikeOpLLVMLowering { }; struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering { - AlignedAllocOpLowering(LLVMTypeConverter &converter) + AlignedAllocOpLowering(const LLVMTypeConverter &converter) : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(), converter) {} std::tuple allocateBuffer(ConversionPatternRewriter &rewriter, @@ -84,7 +85,7 @@ struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering { }; struct AllocaOpLowering : public AllocLikeOpLLVMLowering { - AllocaOpLowering(LLVMTypeConverter &converter) + AllocaOpLowering(const LLVMTypeConverter &converter) : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(), converter) { setRequiresNumElements(); @@ -122,7 +123,7 @@ struct AllocaOpLowering : public AllocLikeOpLLVMLowering { struct ReallocOpLoweringBase : public AllocationOpLLVMLowering { using OpAdaptor = typename memref::ReallocOp::Adaptor; - ReallocOpLoweringBase(LLVMTypeConverter &converter) + ReallocOpLoweringBase(const LLVMTypeConverter &converter) : AllocationOpLLVMLowering(memref::ReallocOp::getOperationName(), converter) {} @@ -247,7 +248,7 @@ struct ReallocOpLoweringBase : public AllocationOpLLVMLowering { }; struct ReallocOpLowering : public ReallocOpLoweringBase { - ReallocOpLowering(LLVMTypeConverter &converter) + ReallocOpLowering(const LLVMTypeConverter &converter) : ReallocOpLoweringBase(converter) {} std::tuple allocateBuffer(ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes, @@ -258,7 +259,7 @@ struct ReallocOpLowering : public ReallocOpLoweringBase { }; struct AlignedReallocOpLowering : public ReallocOpLoweringBase { - AlignedReallocOpLowering(LLVMTypeConverter &converter) + AlignedReallocOpLowering(const LLVMTypeConverter &converter) : ReallocOpLoweringBase(converter) {} std::tuple allocateBuffer(ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes, @@ -334,7 +335,7 @@ struct AssumeAlignmentOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern< memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern; - explicit AssumeAlignmentOpLowering(LLVMTypeConverter &converter) + explicit AssumeAlignmentOpLowering(const LLVMTypeConverter &converter) : ConvertOpToLLVMPattern(converter) {} LogicalResult @@ -376,7 +377,7 @@ struct AssumeAlignmentOpLowering struct DeallocOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - explicit DeallocOpLowering(LLVMTypeConverter &converter) + explicit DeallocOpLowering(const LLVMTypeConverter &converter) : ConvertOpToLLVMPattern(converter) {} LogicalResult @@ -635,8 +636,9 @@ struct GenericAtomicRMWOpLowering }; /// Returns the LLVM type of the global variable given the memref type `type`. -static Type convertGlobalMemrefTypeToLLVM(MemRefType type, - LLVMTypeConverter &typeConverter) { +static Type +convertGlobalMemrefTypeToLLVM(MemRefType type, + const LLVMTypeConverter &typeConverter) { // LLVM type for a global memref will be a multi-dimension array. For // declarations or uninitialized global memrefs, we can potentially flatten // this to a 1D array. However, for memref.global's with an initial value, @@ -703,7 +705,7 @@ struct GlobalMemrefOpLowering /// the first element stashed into the descriptor. This reuses /// `AllocLikeOpLowering` to reuse the Memref descriptor construction. struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering { - GetGlobalMemrefOpLowering(LLVMTypeConverter &converter) + GetGlobalMemrefOpLowering(const LLVMTypeConverter &converter) : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(), converter) {} @@ -1191,7 +1193,7 @@ struct MemorySpaceCastOpLowering /// ranked descriptor. static void extractPointersAndOffset(Location loc, ConversionPatternRewriter &rewriter, - LLVMTypeConverter &typeConverter, + const LLVMTypeConverter &typeConverter, Value originalOperand, Value convertedOperand, Value *allocatedPtr, Value *alignedPtr, diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp index 1d85e64bdfbfc3..f024bdfda93888 100644 --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -61,10 +61,10 @@ static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits, /// 1D array (spirv.array or spirv.rt_array), the last index is modified to load /// the bits needed. The extraction of the actual bits needed are handled /// separately. Note that this only works for a 1-D tensor. -static Value adjustAccessChainForBitwidth(SPIRVTypeConverter &typeConverter, - spirv::AccessChainOp op, - int sourceBits, int targetBits, - OpBuilder &builder) { +static Value +adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter, + spirv::AccessChainOp op, int sourceBits, + int targetBits, OpBuilder &builder) { assert(targetBits % sourceBits == 0); const auto loc = op.getLoc(); IntegerType targetType = builder.getIntegerType(targetBits); @@ -277,7 +277,7 @@ class CastPattern final : public OpConversionPattern { Value src = adaptor.getSource(); Type srcType = src.getType(); - TypeConverter *converter = getTypeConverter(); + const TypeConverter *converter = getTypeConverter(); Type dstType = converter->convertType(op.getType()); if (srcType != dstType) return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { @@ -436,7 +436,7 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, if (!memrefType.getElementType().isSignlessInteger()) return failure(); - auto &typeConverter = *getTypeConverter(); + const auto &typeConverter = *getTypeConverter(); Value accessChain = spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(), adaptor.getIndices(), loc, rewriter); @@ -768,7 +768,7 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite( diag << "invalid src type " << src.getType(); }); - TypeConverter *converter = getTypeConverter(); + const TypeConverter *converter = getTypeConverter(); auto dstType = converter->convertType(op.getType()); if (dstType != srcType) diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index 64394de91d4dd7..21c6780cc7887f 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -236,7 +236,7 @@ MemRefType nvgpu::getMBarrierMemrefType(MLIRContext *context, /// Returns the base pointer of the mbarrier object. static Value getMbarrierPtr(ConversionPatternRewriter &rewriter, - LLVMTypeConverter &typeConverter, + const LLVMTypeConverter &typeConverter, TypedValue barrier, Value barrierMemref) { MemRefType memrefType = diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp index fb8ad5a4c31f54..d06b7033257196 100644 --- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp +++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp @@ -58,7 +58,7 @@ struct RegionLessOpWithVarOperandsConversion LogicalResult matchAndRewrite(T curOp, typename T::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter(); + const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter(); SmallVector resTypes; if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes))) return failure(); @@ -90,7 +90,7 @@ struct RegionOpWithVarOperandsConversion : public ConvertOpToLLVMPattern { LogicalResult matchAndRewrite(T curOp, typename T::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter(); + const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter(); SmallVector resTypes; if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes))) return failure(); @@ -128,7 +128,7 @@ struct RegionLessOpConversion : public ConvertOpToLLVMPattern { LogicalResult matchAndRewrite(T curOp, typename T::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter(); + const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter(); SmallVector resTypes; if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes))) return failure(); @@ -145,7 +145,7 @@ struct AtomicReadOpConversion LogicalResult matchAndRewrite(omp::AtomicReadOp curOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter(); + const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter(); Type curElementType = curOp.getElementType(); auto newOp = rewriter.create( curOp.getLoc(), TypeRange(), adaptor.getOperands(), curOp->getAttrs()); diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index e325348242affe..92f7aa69760395 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -37,7 +37,7 @@ static VectorType reducedVectorTypeBack(VectorType tp) { // Helper that picks the proper sequence for inserting. static Value insertOne(ConversionPatternRewriter &rewriter, - LLVMTypeConverter &typeConverter, Location loc, + const LLVMTypeConverter &typeConverter, Location loc, Value val1, Value val2, Type llvmType, int64_t rank, int64_t pos) { assert(rank > 0 && "0-D vector corner case should have been handled already"); @@ -54,7 +54,7 @@ static Value insertOne(ConversionPatternRewriter &rewriter, // Helper that picks the proper sequence for extracting. static Value extractOne(ConversionPatternRewriter &rewriter, - LLVMTypeConverter &typeConverter, Location loc, + const LLVMTypeConverter &typeConverter, Location loc, Value val, Type llvmType, int64_t rank, int64_t pos) { if (rank <= 1) { auto idxType = rewriter.getIndexType(); @@ -68,7 +68,7 @@ static Value extractOne(ConversionPatternRewriter &rewriter, } // Helper that returns data layout alignment of a memref. -LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, +LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter, MemRefType memrefType, unsigned &align) { Type elementTy = typeConverter.convertType(memrefType.getElementType()); if (!elementTy) @@ -84,7 +84,7 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, // Check if the last stride is non-unit or the memory space is not zero. static LogicalResult isMemRefTypeSupported(MemRefType memRefType, - LLVMTypeConverter &converter) { + const LLVMTypeConverter &converter) { if (!isLastMemrefDimUnitStride(memRefType)) return failure(); FailureOr addressSpace = @@ -96,7 +96,7 @@ static LogicalResult isMemRefTypeSupported(MemRefType memRefType, // Add an index vector component to a base pointer. static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, - LLVMTypeConverter &typeConverter, + const LLVMTypeConverter &typeConverter, MemRefType memRefType, Value llvmMemref, Value base, Value index, uint64_t vLen) { assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) && @@ -112,7 +112,7 @@ static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, // will be in the same address space as the incoming memref type. static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc, Value ptr, MemRefType memRefType, Type vt, - LLVMTypeConverter &converter) { + const LLVMTypeConverter &converter) { if (converter.useOpaquePointers()) return ptr; @@ -294,7 +294,7 @@ class VectorGatherOpConversion return success(); } - LLVMTypeConverter &typeConverter = *this->getTypeConverter(); + const LLVMTypeConverter &typeConverter = *this->getTypeConverter(); auto callback = [align, memRefType, base, ptr, loc, &rewriter, &typeConverter](Type llvm1DVectorTy, ValueRange vectorOperands) { @@ -672,7 +672,7 @@ static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, class VectorReductionOpConversion : public ConvertOpToLLVMPattern { public: - explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv, + explicit VectorReductionOpConversion(const LLVMTypeConverter &typeConv, bool reassociateFPRed) : ConvertOpToLLVMPattern(typeConv), reassociateFPReductions(reassociateFPRed) {} diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp index c19f8f182a923d..1355af14660776 100644 --- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp @@ -24,7 +24,7 @@ namespace { /// dimension directly translates into the number of rows of the tiles. /// The second dimensions needs to be scaled by the number of bytes. std::pair getTileSizes(ConversionPatternRewriter &rewriter, - LLVMTypeConverter &typeConverter, + const LLVMTypeConverter &typeConverter, VectorType vType, Location loc) { Type llvmInt16Type = IntegerType::get(&typeConverter.getContext(), 16); unsigned width = vType.getElementType().getIntOrFloatBitWidth(); @@ -52,8 +52,8 @@ LogicalResult verifyStride(MemRefType mType) { /// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer /// shape may "envelop" the actual tile shape, and may be dynamically sized. Value getStride(ConversionPatternRewriter &rewriter, - LLVMTypeConverter &typeConverter, MemRefType mType, Value base, - Location loc) { + const LLVMTypeConverter &typeConverter, MemRefType mType, + Value base, Location loc) { assert(mType.getRank() >= 2); int64_t last = mType.getRank() - 1; Type llvmInt64Type = IntegerType::get(&typeConverter.getContext(), 64); diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp index e768940cc27b5b..2b654db87fe4ff 100644 --- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp @@ -80,7 +80,7 @@ LogicalResult EmulateFloatPattern::match(Operation *op) const { void EmulateFloatPattern::rewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { Location loc = op->getLoc(); - TypeConverter *converter = getTypeConverter(); + const TypeConverter *converter = getTypeConverter(); SmallVector resultTypes; if (failed(converter->convertTypes(op->getResultTypes(), resultTypes))) { // Note to anyone looking for this error message: this is a "can't happen". diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 94ce4ebb812947..c75d217663a9e0 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -132,7 +132,7 @@ MLIRContext *SPIRVTypeConverter::getContext() const { return targetEnv.getAttr().getContext(); } -bool SPIRVTypeConverter::allows(spirv::Capability capability) { +bool SPIRVTypeConverter::allows(spirv::Capability capability) const { return targetEnv.allows(capability); } @@ -992,7 +992,7 @@ Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef strides, return linearizedIndex; } -Value mlir::spirv::getVulkanElementPtr(SPIRVTypeConverter &typeConverter, +Value mlir::spirv::getVulkanElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder) { @@ -1023,7 +1023,7 @@ Value mlir::spirv::getVulkanElementPtr(SPIRVTypeConverter &typeConverter, return builder.create(loc, basePtr, linearizedIndices); } -Value mlir::spirv::getOpenCLElementPtr(SPIRVTypeConverter &typeConverter, +Value mlir::spirv::getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder) { @@ -1058,7 +1058,7 @@ Value mlir::spirv::getOpenCLElementPtr(SPIRVTypeConverter &typeConverter, linearizedIndices); } -Value mlir::spirv::getElementPtr(SPIRVTypeConverter &typeConverter, +Value mlir::spirv::getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder) { diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp index b36f2978d20e38..c33304c18fe48a 100644 --- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp @@ -40,8 +40,8 @@ struct LowerToIntrinsic : public OpConversionPattern { explicit LowerToIntrinsic(LLVMTypeConverter &converter) : OpConversionPattern(converter, &converter.getContext()) {} - LLVMTypeConverter &getTypeConverter() const { - return *static_cast( + const LLVMTypeConverter &getTypeConverter() const { + return *static_cast( OpConversionPattern::getTypeConverter()); } diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index fa75d6efa15bb2..78d7b47558b553 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -226,11 +226,12 @@ class OperationTransactionState { /// This class represents one requested operation replacement via 'replaceOp' or /// 'eraseOp`. struct OpReplacement { - OpReplacement(TypeConverter *converter = nullptr) : converter(converter) {} + OpReplacement(const TypeConverter *converter = nullptr) + : converter(converter) {} /// An optional type converter that can be used to materialize conversions /// between the new and old values if necessary. - TypeConverter *converter; + const TypeConverter *converter; }; //===----------------------------------------------------------------------===// @@ -333,7 +334,7 @@ class UnresolvedMaterialization { }; UnresolvedMaterialization(UnrealizedConversionCastOp op = nullptr, - TypeConverter *converter = nullptr, + const TypeConverter *converter = nullptr, Kind kind = Target, Type origOutputType = nullptr) : op(op), converterAndKind(converter, kind), origOutputType(origOutputType) {} @@ -343,7 +344,9 @@ class UnresolvedMaterialization { UnrealizedConversionCastOp getOp() const { return op; } /// Return the type converter of this materialization (which may be null). - TypeConverter *getConverter() const { return converterAndKind.getPointer(); } + const TypeConverter *getConverter() const { + return converterAndKind.getPointer(); + } /// Return the kind of this materialization. Kind getKind() const { return converterAndKind.getInt(); } @@ -360,7 +363,7 @@ class UnresolvedMaterialization { /// The corresponding type converter to use when resolving this /// materialization, and the kind of this materialization. - llvm::PointerIntPair converterAndKind; + llvm::PointerIntPair converterAndKind; /// The original output type. This is only used for argument conversions. Type origOutputType; @@ -372,7 +375,7 @@ class UnresolvedMaterialization { static Value buildUnresolvedMaterialization( UnresolvedMaterialization::Kind kind, Block *insertBlock, Block::iterator insertPt, Location loc, ValueRange inputs, Type outputType, - Type origOutputType, TypeConverter *converter, + Type origOutputType, const TypeConverter *converter, SmallVectorImpl &unresolvedMaterializations) { // Avoid materializing an unnecessary cast. if (inputs.size() == 1 && inputs.front().getType() == outputType) @@ -389,7 +392,7 @@ static Value buildUnresolvedMaterialization( } static Value buildUnresolvedArgumentMaterialization( PatternRewriter &rewriter, Location loc, ValueRange inputs, - Type origOutputType, Type outputType, TypeConverter *converter, + Type origOutputType, Type outputType, const TypeConverter *converter, SmallVectorImpl &unresolvedMaterializations) { return buildUnresolvedMaterialization( UnresolvedMaterialization::Argument, rewriter.getInsertionBlock(), @@ -397,7 +400,7 @@ static Value buildUnresolvedArgumentMaterialization( converter, unresolvedMaterializations); } static Value buildUnresolvedTargetMaterialization( - Location loc, Value input, Type outputType, TypeConverter *converter, + Location loc, Value input, Type outputType, const TypeConverter *converter, SmallVectorImpl &unresolvedMaterializations) { Block *insertBlock = input.getParentBlock(); Block::iterator insertPt = insertBlock->begin(); @@ -446,7 +449,7 @@ struct ArgConverter { /// This structure contains information pertaining to a block that has had its /// signature converted. struct ConvertedBlockInfo { - ConvertedBlockInfo(Block *origBlock, TypeConverter *converter) + ConvertedBlockInfo(Block *origBlock, const TypeConverter *converter) : origBlock(origBlock), converter(converter) {} /// The original block that was requested to have its signature converted. @@ -457,7 +460,7 @@ struct ArgConverter { SmallVector, 1> argInfo; /// The type converter used to convert the arguments. - TypeConverter *converter; + const TypeConverter *converter; }; /// Return if the signature of the given block has already been converted. @@ -466,14 +469,14 @@ struct ArgConverter { } /// Set the type converter to use for the given region. - void setConverter(Region *region, TypeConverter *typeConverter) { + void setConverter(Region *region, const TypeConverter *typeConverter) { assert(typeConverter && "expected valid type converter"); regionToConverter[region] = typeConverter; } /// Return the type converter to use for the given region, or null if there /// isn't one. - TypeConverter *getConverter(Region *region) { + const TypeConverter *getConverter(Region *region) { return regionToConverter.lookup(region); } @@ -510,7 +513,7 @@ struct ArgConverter { /// block is returned containing the new arguments. Returns `block` if it did /// not require conversion. FailureOr - convertSignature(Block *block, TypeConverter *converter, + convertSignature(Block *block, const TypeConverter *converter, ConversionValueMapping &mapping, SmallVectorImpl &argReplacements); @@ -521,7 +524,7 @@ struct ArgConverter { /// translate between the origin argument types and those specified in the /// signature conversion. Block *applySignatureConversion( - Block *block, TypeConverter *converter, + Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion, ConversionValueMapping &mapping, SmallVectorImpl &argReplacements); @@ -542,7 +545,7 @@ struct ArgConverter { /// A mapping of regions to type converters that should be used when /// converting the arguments of blocks within that region. - DenseMap regionToConverter; + DenseMap regionToConverter; /// The pattern rewriter to use when materializing conversions. PatternRewriter &rewriter; @@ -686,7 +689,8 @@ LogicalResult ArgConverter::materializeLiveConversions( // Conversion FailureOr ArgConverter::convertSignature( - Block *block, TypeConverter *converter, ConversionValueMapping &mapping, + Block *block, const TypeConverter *converter, + ConversionValueMapping &mapping, SmallVectorImpl &argReplacements) { // Check if the block was already converted. If the block is detached, // conservatively assume it is going to be deleted. @@ -705,7 +709,7 @@ FailureOr ArgConverter::convertSignature( } Block *ArgConverter::applySignatureConversion( - Block *block, TypeConverter *converter, + Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion, ConversionValueMapping &mapping, SmallVectorImpl &argReplacements) { @@ -865,7 +869,7 @@ struct ConversionPatternRewriterImpl { /// Convert the signature of the given block. FailureOr convertBlockSignature( - Block *block, TypeConverter *converter, + Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion *conversion = nullptr); /// Apply a signature conversion on the given region, using `converter` for @@ -873,16 +877,16 @@ struct ConversionPatternRewriterImpl { Block * applySignatureConversion(Region *region, TypeConverter::SignatureConversion &conversion, - TypeConverter *converter); + const TypeConverter *converter); /// Convert the types of block arguments within the given region. FailureOr - convertRegionTypes(Region *region, TypeConverter &converter, + convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion); /// Convert the types of non-entry block arguments within the given region. LogicalResult convertNonEntryRegionTypes( - Region *region, TypeConverter &converter, + Region *region, const TypeConverter &converter, ArrayRef blockConversions = {}); //===--------------------------------------------------------------------===// @@ -962,7 +966,7 @@ struct ConversionPatternRewriterImpl { /// The current type converter, or nullptr if no type converter is currently /// active. - TypeConverter *currentTypeConverter = nullptr; + const TypeConverter *currentTypeConverter = nullptr; /// This allows the user to collect the match failure message. function_ref notifyCallback; @@ -1283,7 +1287,7 @@ void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) { // Type Conversion FailureOr ConversionPatternRewriterImpl::convertBlockSignature( - Block *block, TypeConverter *converter, + Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion *conversion) { FailureOr result = conversion ? argConverter.applySignatureConversion( @@ -1301,14 +1305,14 @@ FailureOr ConversionPatternRewriterImpl::convertBlockSignature( Block *ConversionPatternRewriterImpl::applySignatureConversion( Region *region, TypeConverter::SignatureConversion &conversion, - TypeConverter *converter) { + const TypeConverter *converter) { if (!region->empty()) return *convertBlockSignature(®ion->front(), converter, &conversion); return nullptr; } FailureOr ConversionPatternRewriterImpl::convertRegionTypes( - Region *region, TypeConverter &converter, + Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion) { argConverter.setConverter(region, &converter); if (region->empty()) @@ -1323,7 +1327,7 @@ FailureOr ConversionPatternRewriterImpl::convertRegionTypes( } LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes( - Region *region, TypeConverter &converter, + Region *region, const TypeConverter &converter, ArrayRef blockConversions) { argConverter.setConverter(region, &converter); if (region->empty()) @@ -1492,18 +1496,18 @@ void ConversionPatternRewriter::eraseBlock(Block *block) { Block *ConversionPatternRewriter::applySignatureConversion( Region *region, TypeConverter::SignatureConversion &conversion, - TypeConverter *converter) { + const TypeConverter *converter) { return impl->applySignatureConversion(region, conversion, converter); } FailureOr ConversionPatternRewriter::convertRegionTypes( - Region *region, TypeConverter &converter, + Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion) { return impl->convertRegionTypes(region, converter, entryConversion); } LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes( - Region *region, TypeConverter &converter, + Region *region, const TypeConverter &converter, ArrayRef blockConversions) { return impl->convertNonEntryRegionTypes(region, converter, blockConversions); } @@ -2341,7 +2345,7 @@ struct OperationConverter { /// type. LogicalResult legalizeChangedResultType( Operation *op, OpResult result, Value newValue, - TypeConverter *replConverter, ConversionPatternRewriter &rewriter, + const TypeConverter *replConverter, ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &rewriterImpl, const DenseMap> &inverseMapping); @@ -2717,7 +2721,7 @@ static LogicalResult legalizeUnresolvedMaterialization( } // Try to materialize the conversion. - if (TypeConverter *converter = mat.getConverter()) { + if (const TypeConverter *converter = mat.getConverter()) { // FIXME: Determine a suitable insertion location when there are multiple // inputs. if (inputOperands.size() == 1) @@ -2836,7 +2840,7 @@ static Operation *findLiveUserOfReplaced( LogicalResult OperationConverter::legalizeChangedResultType( Operation *op, OpResult result, Value newValue, - TypeConverter *replConverter, ConversionPatternRewriter &rewriter, + const TypeConverter *replConverter, ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &rewriterImpl, const DenseMap> &inverseMapping) { Operation *liveUser = @@ -3075,7 +3079,7 @@ TypeConverter::convertTypeAttribute(Type type, Attribute attr) const { //===----------------------------------------------------------------------===// static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, - TypeConverter &typeConverter, + const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter) { FunctionType type = dyn_cast(funcOp.getFunctionType()); if (!type) @@ -3106,7 +3110,7 @@ namespace { struct FunctionOpInterfaceSignatureConversion : public ConversionPattern { FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName, MLIRContext *ctx, - TypeConverter &converter) + const TypeConverter &converter) : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {} LogicalResult @@ -3131,13 +3135,13 @@ struct AnyFunctionOpInterfaceSignatureConversion void mlir::populateFunctionOpInterfaceTypeConversionPattern( StringRef functionLikeOpName, RewritePatternSet &patterns, - TypeConverter &converter) { + const TypeConverter &converter) { patterns.add( functionLikeOpName, patterns.getContext(), converter); } void mlir::populateAnyFunctionOpInterfaceTypeConversionPattern( - RewritePatternSet &patterns, TypeConverter &converter) { + RewritePatternSet &patterns, const TypeConverter &converter) { patterns.add( converter, patterns.getContext()); } @@ -3338,7 +3342,8 @@ void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) { [](PatternRewriter &rewriter, Type type) -> FailureOr { auto &rewriterImpl = static_cast(rewriter).getImpl(); - if (TypeConverter *converter = rewriterImpl.currentTypeConverter) { + if (const TypeConverter *converter = + rewriterImpl.currentTypeConverter) { if (Type newType = converter->convertType(type)) return newType; return failure(); @@ -3351,7 +3356,7 @@ void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) { TypeRange types) -> FailureOr> { auto &rewriterImpl = static_cast(rewriter).getImpl(); - TypeConverter *converter = rewriterImpl.currentTypeConverter; + const TypeConverter *converter = rewriterImpl.currentTypeConverter; if (!converter) return SmallVector(types); diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 46788edcb4df58..30ed4109ad8bd1 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -668,7 +668,8 @@ struct TestUndoBlockErase : public ConversionPattern { /// This patterns erases a region operation that has had a type conversion. struct TestDropOpSignatureConversion : public ConversionPattern { - TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) + TestDropOpSignatureConversion(MLIRContext *ctx, + const TypeConverter &converter) : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -677,7 +678,7 @@ struct TestDropOpSignatureConversion : public ConversionPattern { Block *entry = ®ion.front(); // Convert the original entry arguments. - TypeConverter &converter = *getTypeConverter(); + const TypeConverter &converter = *getTypeConverter(); TypeConverter::SignatureConversion result(entry->getNumArguments()); if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(), result)) || @@ -1307,7 +1308,7 @@ struct TestSignatureConversionUndo /// materializations. struct TestTestSignatureConversionNoConverter : public OpConversionPattern { - TestTestSignatureConversionNoConverter(TypeConverter &converter, + TestTestSignatureConversionNoConverter(const TypeConverter &converter, MLIRContext *context) : OpConversionPattern(context), converter(converter) {} @@ -1328,7 +1329,7 @@ struct TestTestSignatureConversionNoConverter return success(); } - TypeConverter &converter; + const TypeConverter &converter; }; /// Just forward the operands to the root op. This is essentially a no-op