Skip to content

Commit

Permalink
[mlir][Conversion] Store const type converter in ConversionPattern
Browse files Browse the repository at this point in the history
ConversionPatterns do not (and should not) modify the type converter that they are using.

* Make `ConversionPattern::typeConverter` const.
* Make member functions of the `LLVMTypeConverter` const.
* Conversion patterns take a const type converter.
* Various helper functions (that are called from patterns) now also take a const type converter.

Differential Revision: https://reviews.llvm.org/D157601
  • Loading branch information
matthias-springer committed Aug 14, 2023
1 parent ce16c3c commit ce25459
Show file tree
Hide file tree
Showing 35 changed files with 383 additions and 358 deletions.
36 changes: 17 additions & 19 deletions flang/include/flang/Optimizer/CodeGen/TypeConverter.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,56 +49,54 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter {

// i32 is used here because LLVM wants i32 constants when indexing into struct
// types. Indexing into other aggregate types is more flexible.
mlir::Type offsetType();
mlir::Type offsetType() const;

// i64 can be used to index into aggregates like arrays
mlir::Type indexType();
mlir::Type indexType() const;

// fir.type<name(p : TY'...){f : TY...}> --> llvm<"%name = { ty... }">
std::optional<mlir::LogicalResult>
convertRecordType(fir::RecordType derived,
llvm::SmallVectorImpl<mlir::Type> &results,
llvm::ArrayRef<mlir::Type> callStack);
llvm::ArrayRef<mlir::Type> callStack) const;

// Is an extended descriptor needed given the element type of a fir.box type ?
// Extended descriptors are required for derived types.
bool requiresExtendedDesc(mlir::Type boxElementType);
bool requiresExtendedDesc(mlir::Type boxElementType) const;

// Magic value to indicate we do not know the rank of an entity, either
// because it is assumed rank or because we have not determined it yet.
static constexpr int unknownRank() { return -1; }

// This corresponds to the descriptor as defined in ISO_Fortran_binding.h and
// the addendum defined in descriptor.h.
mlir::Type convertBoxType(BaseBoxType box, int rank = unknownRank());
mlir::Type convertBoxType(BaseBoxType box, int rank = unknownRank()) const;

/// Convert fir.box type to the corresponding llvm struct type instead of a
/// pointer to this struct type.
mlir::Type convertBoxTypeAsStruct(BaseBoxType box);
mlir::Type convertBoxTypeAsStruct(BaseBoxType box) const;

// fir.boxproc<any> --> llvm<"{ any*, i8* }">
mlir::Type convertBoxProcType(BoxProcType boxproc);
mlir::Type convertBoxProcType(BoxProcType boxproc) const;

unsigned characterBitsize(fir::CharacterType charTy);
unsigned characterBitsize(fir::CharacterType charTy) const;

// fir.char<k,?> --> llvm<"ix"> where ix is scaled by kind mapping
// fir.char<k,n> --> llvm.array<n x "ix">
mlir::Type convertCharType(fir::CharacterType charTy);
mlir::Type convertCharType(fir::CharacterType charTy) const;

// Use the target specifics to figure out how to map complex to LLVM IR. The
// use of complex values in function signatures is handled before conversion
// to LLVM IR dialect here.
//
// fir.complex<T> | std.complex<T> --> llvm<"{t,t}">
template <typename C>
mlir::Type convertComplexType(C cmplx) {
template <typename C> mlir::Type convertComplexType(C cmplx) const {
LLVM_DEBUG(llvm::dbgs() << "type convert: " << cmplx << '\n');
auto eleTy = cmplx.getElementType();
return convertType(specifics->complexMemoryType(eleTy));
}

template <typename A>
mlir::Type convertPointerLike(A &ty) {
template <typename A> mlir::Type convertPointerLike(A &ty) const {
mlir::Type eleTy = ty.getEleTy();
// A sequence type is a special case. A sequence of runtime size on its
// interior dimensions lowers to a memory reference. In that case, we
Expand Down Expand Up @@ -126,27 +124,27 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter {

// convert a front-end kind value to either a std or LLVM IR dialect type
// fir.real<n> --> llvm.anyfloat where anyfloat is a kind mapping
mlir::Type convertRealType(fir::KindTy kind);
mlir::Type convertRealType(fir::KindTy kind) const;

// fir.array<c ... :any> --> llvm<"[...[c x any]]">
mlir::Type convertSequenceType(SequenceType seq);
mlir::Type convertSequenceType(SequenceType seq) const;

// fir.tdesc<any> --> llvm<"i8*">
// TODO: For now use a void*, however pointer identity is not sufficient for
// the f18 object v. class distinction (F2003).
mlir::Type convertTypeDescType(mlir::MLIRContext *ctx);
mlir::Type convertTypeDescType(mlir::MLIRContext *ctx) const;

KindMapping &getKindMap() { return kindMapping; }
const KindMapping &getKindMap() const { return kindMapping; }

// Relay TBAA tag attachment to TBAABuilder.
void attachTBAATag(mlir::LLVM::AliasAnalysisOpInterface op,
mlir::Type baseFIRType, mlir::Type accessFIRType,
mlir::LLVM::GEPOp gep);
mlir::LLVM::GEPOp gep) const;

private:
KindMapping kindMapping;
std::unique_ptr<CodeGenSpecifics> specifics;
TBAABuilder tbaaBuilder;
std::unique_ptr<TBAABuilder> tbaaBuilder;
};

} // namespace fir
Expand Down
15 changes: 8 additions & 7 deletions flang/lib/Optimizer/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ namespace {
template <typename FromOp>
class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
public:
explicit FIROpConversion(fir::LLVMTypeConverter &lowering,
explicit FIROpConversion(const fir::LLVMTypeConverter &lowering,
const fir::FIRToLLVMPassOptions &options)
: mlir::ConvertOpToLLVMPattern<FromOp>(lowering), options(options) {}

Expand Down Expand Up @@ -359,8 +359,9 @@ class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
return al;
}

fir::LLVMTypeConverter &lowerTy() const {
return *static_cast<fir::LLVMTypeConverter *>(this->getTypeConverter());
const fir::LLVMTypeConverter &lowerTy() const {
return *static_cast<const fir::LLVMTypeConverter *>(
this->getTypeConverter());
}

void attachTBAATag(mlir::LLVM::AliasAnalysisOpInterface op,
Expand Down Expand Up @@ -3191,8 +3192,8 @@ struct SelectCaseOpConversion : public FIROpConversion<fir::SelectCaseOp> {
};

template <typename OP>
static void selectMatchAndRewrite(fir::LLVMTypeConverter &lowering, OP select,
typename OP::Adaptor adaptor,
static void selectMatchAndRewrite(const fir::LLVMTypeConverter &lowering,
OP select, typename OP::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) {
unsigned conds = select.getNumConditions();
auto cases = select.getCases().getValue();
Expand Down Expand Up @@ -3461,7 +3462,7 @@ template <typename LLVMOP, typename OPTY>
static mlir::LLVM::InsertValueOp
complexSum(OPTY sumop, mlir::ValueRange opnds,
mlir::ConversionPatternRewriter &rewriter,
fir::LLVMTypeConverter &lowering) {
const fir::LLVMTypeConverter &lowering) {
mlir::Value a = opnds[0];
mlir::Value b = opnds[1];
auto loc = sumop.getLoc();
Expand Down Expand Up @@ -3610,7 +3611,7 @@ struct NegcOpConversion : public FIROpConversion<fir::NegcOp> {
/// These operations are normally dead after the pre-codegen pass.
template <typename FromOp>
struct MustBeDeadConversion : public FIROpConversion<FromOp> {
explicit MustBeDeadConversion(fir::LLVMTypeConverter &lowering,
explicit MustBeDeadConversion(const fir::LLVMTypeConverter &lowering,
const fir::FIRToLLVMPassOptions &options)
: FIROpConversion<FromOp>(lowering, options) {}
using OpAdaptor = typename FromOp::Adaptor;
Expand Down
37 changes: 19 additions & 18 deletions flang/lib/Optimizer/CodeGen/TypeConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ LLVMTypeConverter::LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA)
specifics(CodeGenSpecifics::get(module.getContext(),
getTargetTriple(module),
getKindMapping(module))),
tbaaBuilder(module->getContext(), applyTBAA) {
tbaaBuilder(
std::make_unique<TBAABuilder>(module->getContext(), applyTBAA)) {
LLVM_DEBUG(llvm::dbgs() << "FIR type converter\n");

// Each conversion should return a value of type mlir::Type.
Expand Down Expand Up @@ -155,20 +156,19 @@ LLVMTypeConverter::LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA)

// i32 is used here because LLVM wants i32 constants when indexing into struct
// types. Indexing into other aggregate types is more flexible.
mlir::Type LLVMTypeConverter::offsetType() {
mlir::Type LLVMTypeConverter::offsetType() const {
return mlir::IntegerType::get(&getContext(), 32);
}

// i64 can be used to index into aggregates like arrays
mlir::Type LLVMTypeConverter::indexType() {
mlir::Type LLVMTypeConverter::indexType() const {
return mlir::IntegerType::get(&getContext(), 64);
}

// fir.type<name(p : TY'...){f : TY...}> --> llvm<"%name = { ty... }">
std::optional<mlir::LogicalResult>
LLVMTypeConverter::convertRecordType(fir::RecordType derived,
llvm::SmallVectorImpl<mlir::Type> &results,
llvm::ArrayRef<mlir::Type> callStack) {
std::optional<mlir::LogicalResult> LLVMTypeConverter::convertRecordType(
fir::RecordType derived, llvm::SmallVectorImpl<mlir::Type> &results,
llvm::ArrayRef<mlir::Type> callStack) const {
auto name = derived.getName();
auto st = mlir::LLVM::LLVMStructType::getIdentified(&getContext(), name);
if (llvm::count(callStack, derived) > 1) {
Expand All @@ -192,14 +192,14 @@ LLVMTypeConverter::convertRecordType(fir::RecordType derived,

// Is an extended descriptor needed given the element type of a fir.box type ?
// Extended descriptors are required for derived types.
bool LLVMTypeConverter::requiresExtendedDesc(mlir::Type boxElementType) {
bool LLVMTypeConverter::requiresExtendedDesc(mlir::Type boxElementType) const {
auto eleTy = fir::unwrapSequenceType(boxElementType);
return eleTy.isa<fir::RecordType>();
}

// This corresponds to the descriptor as defined in ISO_Fortran_binding.h and
// the addendum defined in descriptor.h.
mlir::Type LLVMTypeConverter::convertBoxType(BaseBoxType box, int rank) {
mlir::Type LLVMTypeConverter::convertBoxType(BaseBoxType box, int rank) const {
// (base_addr*, elem_len, version, rank, type, attribute, f18Addendum, [dim]
llvm::SmallVector<mlir::Type> dataDescFields;
mlir::Type ele = box.getEleTy();
Expand Down Expand Up @@ -269,14 +269,14 @@ mlir::Type LLVMTypeConverter::convertBoxType(BaseBoxType box, int rank) {

/// Convert fir.box type to the corresponding llvm struct type instead of a
/// pointer to this struct type.
mlir::Type LLVMTypeConverter::convertBoxTypeAsStruct(BaseBoxType box) {
mlir::Type LLVMTypeConverter::convertBoxTypeAsStruct(BaseBoxType box) const {
return convertBoxType(box)
.cast<mlir::LLVM::LLVMPointerType>()
.getElementType();
}

// fir.boxproc<any> --> llvm<"{ any*, i8* }">
mlir::Type LLVMTypeConverter::convertBoxProcType(BoxProcType boxproc) {
mlir::Type LLVMTypeConverter::convertBoxProcType(BoxProcType boxproc) const {
auto funcTy = convertType(boxproc.getEleTy());
auto i8PtrTy = mlir::LLVM::LLVMPointerType::get(
mlir::IntegerType::get(&getContext(), 8));
Expand All @@ -285,13 +285,13 @@ mlir::Type LLVMTypeConverter::convertBoxProcType(BoxProcType boxproc) {
/*isPacked=*/false);
}

unsigned LLVMTypeConverter::characterBitsize(fir::CharacterType charTy) {
unsigned LLVMTypeConverter::characterBitsize(fir::CharacterType charTy) const {
return kindMapping.getCharacterBitsize(charTy.getFKind());
}

// fir.char<k,?> --> llvm<"ix"> where ix is scaled by kind mapping
// fir.char<k,n> --> llvm.array<n x "ix">
mlir::Type LLVMTypeConverter::convertCharType(fir::CharacterType charTy) {
mlir::Type LLVMTypeConverter::convertCharType(fir::CharacterType charTy) const {
auto iTy = mlir::IntegerType::get(&getContext(), characterBitsize(charTy));
if (charTy.getLen() == fir::CharacterType::unknownLen())
return iTy;
Expand All @@ -300,13 +300,13 @@ mlir::Type LLVMTypeConverter::convertCharType(fir::CharacterType charTy) {

// convert a front-end kind value to either a std or LLVM IR dialect type
// fir.real<n> --> llvm.anyfloat where anyfloat is a kind mapping
mlir::Type LLVMTypeConverter::convertRealType(fir::KindTy kind) {
mlir::Type LLVMTypeConverter::convertRealType(fir::KindTy kind) const {
return fir::fromRealTypeID(&getContext(), kindMapping.getRealTypeID(kind),
kind);
}

// fir.array<c ... :any> --> llvm<"[...[c x any]]">
mlir::Type LLVMTypeConverter::convertSequenceType(SequenceType seq) {
mlir::Type LLVMTypeConverter::convertSequenceType(SequenceType seq) const {
auto baseTy = convertType(seq.getEleTy());
if (characterWithDynamicLen(seq.getEleTy()))
return mlir::LLVM::LLVMPointerType::get(baseTy);
Expand All @@ -328,7 +328,8 @@ mlir::Type LLVMTypeConverter::convertSequenceType(SequenceType seq) {
// fir.tdesc<any> --> llvm<"i8*">
// TODO: For now use a void*, however pointer identity is not sufficient for
// the f18 object v. class distinction (F2003).
mlir::Type LLVMTypeConverter::convertTypeDescType(mlir::MLIRContext *ctx) {
mlir::Type
LLVMTypeConverter::convertTypeDescType(mlir::MLIRContext *ctx) const {
return mlir::LLVM::LLVMPointerType::get(
mlir::IntegerType::get(&getContext(), 8));
}
Expand All @@ -337,8 +338,8 @@ mlir::Type LLVMTypeConverter::convertTypeDescType(mlir::MLIRContext *ctx) {
void LLVMTypeConverter::attachTBAATag(mlir::LLVM::AliasAnalysisOpInterface op,
mlir::Type baseFIRType,
mlir::Type accessFIRType,
mlir::LLVM::GEPOp gep) {
tbaaBuilder.attachTBAATag(op, baseFIRType, accessFIRType, gep);
mlir::LLVM::GEPOp gep) const {
tbaaBuilder->attachTBAATag(op, baseFIRType, accessFIRType, gep);
}

} // namespace fir
Loading

0 comments on commit ce25459

Please sign in to comment.