From de10d03e38a5500edf4e00a09f2a6caa47e60a39 Mon Sep 17 00:00:00 2001 From: Chuanqi Xu Date: Mon, 11 Nov 2024 17:59:29 +0800 Subject: [PATCH] [CIR] [Lowering] [NFC] Split LowerToLLVM.h from LowerToLLVM.cpp --- .../CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 5879 ++++++++--------- .../CIR/Lowering/DirectToLLVM/LowerToLLVM.h | 1024 +++ clang/utils/TableGen/CIRLoweringEmitter.cpp | 33 +- 3 files changed, 3712 insertions(+), 3224 deletions(-) create mode 100644 clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index 97ea6a588853..e0a87dd78413 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -9,19 +9,15 @@ // This file implements lowering of CIR operations to LLVMIR. // //===----------------------------------------------------------------------===// +#include "LowerToLLVM.h" #include "LoweringHelpers.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" -#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" -#include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/LLVMIR/Transforms/Passes.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" @@ -45,11 +41,6 @@ #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Export.h" -#include "mlir/Transforms/DialectConversion.h" -#include "clang/CIR/Dialect/IR/CIRAttrs.h" -#include "clang/CIR/Dialect/IR/CIRDialect.h" -#include "clang/CIR/Dialect/IR/CIROpsEnums.h" -#include "clang/CIR/Dialect/IR/CIRTypes.h" #include "clang/CIR/Dialect/Passes.h" #include "clang/CIR/LoweringHelpers.h" #include "clang/CIR/MissingFeatures.h" @@ -70,8 +61,6 @@ #include #include -#include "LowerModule.h" - using namespace cir; using namespace llvm; @@ -367,13 +356,9 @@ unsigned getGlobalOpTargetAddrSpace(mlir::ConversionPatternRewriter &rewriter, //===----------------------------------------------------------------------===// /// Switches on the type of attribute and calls the appropriate conversion. -inline mlir::Value -lowerCirAttrAsValue(mlir::Operation *parentOp, mlir::Attribute attr, - mlir::ConversionPatternRewriter &rewriter, - const mlir::TypeConverter *converter); /// IntAttr visitor. -inline mlir::Value +static mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, cir::IntAttr intAttr, mlir::ConversionPatternRewriter &rewriter, const mlir::TypeConverter *converter) { @@ -383,7 +368,7 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::IntAttr intAttr, } /// BoolAttr visitor. -inline mlir::Value +static mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, cir::BoolAttr boolAttr, mlir::ConversionPatternRewriter &rewriter, const mlir::TypeConverter *converter) { @@ -393,7 +378,7 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::BoolAttr boolAttr, } /// ConstPtrAttr visitor. -inline mlir::Value +static mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstPtrAttr ptrAttr, mlir::ConversionPatternRewriter &rewriter, const mlir::TypeConverter *converter) { @@ -411,7 +396,7 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstPtrAttr ptrAttr, } /// FPAttr visitor. -inline mlir::Value +static mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, cir::FPAttr fltAttr, mlir::ConversionPatternRewriter &rewriter, const mlir::TypeConverter *converter) { @@ -421,7 +406,7 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::FPAttr fltAttr, } /// ZeroAttr visitor. -inline mlir::Value +static mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ZeroAttr zeroAttr, mlir::ConversionPatternRewriter &rewriter, const mlir::TypeConverter *converter) { @@ -431,7 +416,7 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ZeroAttr zeroAttr, } /// UndefAttr visitor. -inline mlir::Value +static mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, cir::UndefAttr undefAttr, mlir::ConversionPatternRewriter &rewriter, const mlir::TypeConverter *converter) { @@ -441,10 +426,10 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::UndefAttr undefAttr, } /// ConstStruct visitor. -mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, - cir::ConstStructAttr constStruct, - mlir::ConversionPatternRewriter &rewriter, - const mlir::TypeConverter *converter) { +static mlir::Value +lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstStructAttr constStruct, + mlir::ConversionPatternRewriter &rewriter, + const mlir::TypeConverter *converter) { auto llvmTy = converter->convertType(constStruct.getType()); auto loc = parentOp->getLoc(); mlir::Value result = rewriter.create(loc, llvmTy); @@ -467,10 +452,10 @@ mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, } // VTableAttr visitor. -mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, - cir::VTableAttr vtableArr, - mlir::ConversionPatternRewriter &rewriter, - const mlir::TypeConverter *converter) { +static mlir::Value +lowerCirAttrAsValue(mlir::Operation *parentOp, cir::VTableAttr vtableArr, + mlir::ConversionPatternRewriter &rewriter, + const mlir::TypeConverter *converter) { auto llvmTy = converter->convertType(vtableArr.getType()); auto loc = parentOp->getLoc(); mlir::Value result = rewriter.create(loc, llvmTy); @@ -484,10 +469,10 @@ mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, } // TypeInfoAttr visitor. -mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, - cir::TypeInfoAttr typeinfoArr, - mlir::ConversionPatternRewriter &rewriter, - const mlir::TypeConverter *converter) { +static mlir::Value +lowerCirAttrAsValue(mlir::Operation *parentOp, cir::TypeInfoAttr typeinfoArr, + mlir::ConversionPatternRewriter &rewriter, + const mlir::TypeConverter *converter) { auto llvmTy = converter->convertType(typeinfoArr.getType()); auto loc = parentOp->getLoc(); mlir::Value result = rewriter.create(loc, llvmTy); @@ -501,10 +486,10 @@ mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, } // ConstArrayAttr visitor -mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, - cir::ConstArrayAttr constArr, - mlir::ConversionPatternRewriter &rewriter, - const mlir::TypeConverter *converter) { +static mlir::Value +lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstArrayAttr constArr, + mlir::ConversionPatternRewriter &rewriter, + const mlir::TypeConverter *converter) { auto llvmTy = converter->convertType(constArr.getType()); auto loc = parentOp->getLoc(); mlir::Value result; @@ -547,10 +532,10 @@ mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, } // ConstVectorAttr visitor. -mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, - cir::ConstVectorAttr constVec, - mlir::ConversionPatternRewriter &rewriter, - const mlir::TypeConverter *converter) { +static mlir::Value +lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstVectorAttr constVec, + mlir::ConversionPatternRewriter &rewriter, + const mlir::TypeConverter *converter) { auto llvmTy = converter->convertType(constVec.getType()); auto loc = parentOp->getLoc(); SmallVector mlirValues; @@ -575,10 +560,10 @@ mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, } // GlobalViewAttr visitor. -mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, - cir::GlobalViewAttr globalAttr, - mlir::ConversionPatternRewriter &rewriter, - const mlir::TypeConverter *converter) { +static mlir::Value +lowerCirAttrAsValue(mlir::Operation *parentOp, cir::GlobalViewAttr globalAttr, + mlir::ConversionPatternRewriter &rewriter, + const mlir::TypeConverter *converter) { auto module = parentOp->getParentOfType(); mlir::Type sourceType; unsigned sourceAddrSpace = 0; @@ -646,10 +631,9 @@ mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, } /// Switches on the type of attribute and calls the appropriate conversion. -inline mlir::Value -lowerCirAttrAsValue(mlir::Operation *parentOp, mlir::Attribute attr, - mlir::ConversionPatternRewriter &rewriter, - const mlir::TypeConverter *converter) { +mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, mlir::Attribute attr, + mlir::ConversionPatternRewriter &rewriter, + const mlir::TypeConverter *converter) { if (const auto intAttr = mlir::dyn_cast(attr)) return lowerCirAttrAsValue(parentOp, intAttr, rewriter, converter); if (const auto fltAttr = mlir::dyn_cast(attr)) @@ -723,105 +707,75 @@ mlir::LLVM::CConv convertCallingConv(cir::CallingConv callinvConv) { llvm_unreachable("Unknown calling convention"); } -class CIRCopyOpLowering : public mlir::OpConversionPattern { -public: - using mlir::OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::CopyOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - const mlir::Value length = rewriter.create( - op.getLoc(), rewriter.getI32Type(), op.getLength()); - rewriter.replaceOpWithNewOp( - op, adaptor.getDst(), adaptor.getSrc(), length, op.getIsVolatile()); - return mlir::success(); - } -}; - -class CIRMemCpyOpLowering : public mlir::OpConversionPattern { -public: - using mlir::OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::MemCpyOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp( - op, adaptor.getDst(), adaptor.getSrc(), adaptor.getLen(), - /*isVolatile=*/false); - return mlir::success(); - } -}; - -class CIRMemChrOpLowering : public mlir::OpConversionPattern { -public: - using mlir::OpConversionPattern::OpConversionPattern; +mlir::LogicalResult CIRToLLVMCopyOpLowering::matchAndRewrite( + cir::CopyOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + const mlir::Value length = rewriter.create( + op.getLoc(), rewriter.getI32Type(), op.getLength()); + rewriter.replaceOpWithNewOp( + op, adaptor.getDst(), adaptor.getSrc(), length, op.getIsVolatile()); + return mlir::success(); +} - mlir::LogicalResult - matchAndRewrite(cir::MemChrOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext()); - llvm::SmallVector arguments; - const mlir::TypeConverter *converter = getTypeConverter(); - mlir::Type srcTy = converter->convertType(op.getSrc().getType()); - mlir::Type patternTy = converter->convertType(op.getPattern().getType()); - mlir::Type lenTy = converter->convertType(op.getLen().getType()); - auto fnTy = - mlir::LLVM::LLVMFunctionType::get(llvmPtrTy, {srcTy, patternTy, lenTy}, - /*isVarArg=*/false); - llvm::StringRef fnName = "memchr"; - getOrCreateLLVMFuncOp(rewriter, op, fnName, fnTy); - rewriter.replaceOpWithNewOp( - op, mlir::TypeRange{llvmPtrTy}, fnName, - mlir::ValueRange{adaptor.getSrc(), adaptor.getPattern(), - adaptor.getLen()}); - return mlir::success(); - } -}; +mlir::LogicalResult CIRToLLVMMemCpyOpLowering::matchAndRewrite( + cir::MemCpyOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + rewriter.replaceOpWithNewOp( + op, adaptor.getDst(), adaptor.getSrc(), adaptor.getLen(), + /*isVolatile=*/false); + return mlir::success(); +} -class CIRMemCpyInlineOpLowering - : public mlir::OpConversionPattern { -public: - using mlir::OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::MemCpyInlineOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp( - op, adaptor.getDst(), adaptor.getSrc(), adaptor.getLenAttr(), - /*isVolatile=*/false); - return mlir::success(); - } -}; +mlir::LogicalResult CIRToLLVMMemChrOpLowering::matchAndRewrite( + cir::MemChrOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext()); + llvm::SmallVector arguments; + const mlir::TypeConverter *converter = getTypeConverter(); + mlir::Type srcTy = converter->convertType(op.getSrc().getType()); + mlir::Type patternTy = converter->convertType(op.getPattern().getType()); + mlir::Type lenTy = converter->convertType(op.getLen().getType()); + auto fnTy = + mlir::LLVM::LLVMFunctionType::get(llvmPtrTy, {srcTy, patternTy, lenTy}, + /*isVarArg=*/false); + llvm::StringRef fnName = "memchr"; + getOrCreateLLVMFuncOp(rewriter, op, fnName, fnTy); + rewriter.replaceOpWithNewOp( + op, mlir::TypeRange{llvmPtrTy}, fnName, + mlir::ValueRange{adaptor.getSrc(), adaptor.getPattern(), + adaptor.getLen()}); + return mlir::success(); +} -class CIRMemMoveOpLowering : public mlir::OpConversionPattern { -public: - using mlir::OpConversionPattern::OpConversionPattern; +mlir::LogicalResult CIRToLLVMMemMoveOpLowering::matchAndRewrite( + cir::MemMoveOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + rewriter.replaceOpWithNewOp( + op, adaptor.getDst(), adaptor.getSrc(), adaptor.getLen(), + /*isVolatile=*/false); + return mlir::success(); +} - mlir::LogicalResult - matchAndRewrite(cir::MemMoveOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp( - op, adaptor.getDst(), adaptor.getSrc(), adaptor.getLen(), - /*isVolatile=*/false); - return mlir::success(); - } -}; +mlir::LogicalResult CIRToLLVMMemCpyInlineOpLowering::matchAndRewrite( + cir::MemCpyInlineOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + rewriter.replaceOpWithNewOp( + op, adaptor.getDst(), adaptor.getSrc(), adaptor.getLenAttr(), + /*isVolatile=*/false); + return mlir::success(); +} -class CIRMemsetOpLowering : public mlir::OpConversionPattern { -public: - using mlir::OpConversionPattern::OpConversionPattern; - mlir::LogicalResult - matchAndRewrite(cir::MemSetOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - auto converted = rewriter.create( - op.getLoc(), mlir::IntegerType::get(op.getContext(), 8), - adaptor.getVal()); - rewriter.replaceOpWithNewOp( - op, adaptor.getDst(), converted, adaptor.getLen(), - /*isVolatile=*/false); - return mlir::success(); - } -}; +mlir::LogicalResult CIRToLLVMMemSetOpLowering::matchAndRewrite( + cir::MemSetOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto converted = rewriter.create( + op.getLoc(), mlir::IntegerType::get(op.getContext(), 8), + adaptor.getVal()); + rewriter.replaceOpWithNewOp(op, adaptor.getDst(), + converted, adaptor.getLen(), + /*isVolatile=*/false); + return mlir::success(); +} static mlir::Value getLLVMIntCast(mlir::ConversionPatternRewriter &rewriter, mlir::Value llvmSrc, mlir::Type llvmDstIntTy, @@ -841,138 +795,120 @@ static mlir::Value getLLVMIntCast(mlir::ConversionPatternRewriter &rewriter, return rewriter.create(loc, llvmDstIntTy, llvmSrc); } -class CIRPtrStrideOpLowering - : public mlir::OpConversionPattern { -public: - using mlir::OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::PtrStrideOp ptrStrideOp, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - auto *tc = getTypeConverter(); - const auto resultTy = tc->convertType(ptrStrideOp.getType()); - auto elementTy = tc->convertType(ptrStrideOp.getElementTy()); - auto *ctx = elementTy.getContext(); - - // void and function types doesn't really have a layout to use in GEPs, - // make it i8 instead. - if (mlir::isa(elementTy) || - mlir::isa(elementTy)) - elementTy = mlir::IntegerType::get(elementTy.getContext(), 8, - mlir::IntegerType::Signless); - - // Zero-extend, sign-extend or trunc the pointer value. - auto index = adaptor.getStride(); - auto width = mlir::cast(index.getType()).getWidth(); - mlir::DataLayout LLVMLayout(ptrStrideOp->getParentOfType()); - auto layoutWidth = - LLVMLayout.getTypeIndexBitwidth(adaptor.getBase().getType()); - auto indexOp = index.getDefiningOp(); - if (indexOp && layoutWidth && width != *layoutWidth) { - // If the index comes from a subtraction, make sure the extension happens - // before it. To achieve that, look at unary minus, which already got - // lowered to "sub 0, x". - auto sub = dyn_cast(indexOp); - auto unary = dyn_cast_if_present( - ptrStrideOp.getStride().getDefiningOp()); - bool rewriteSub = - unary && unary.getKind() == cir::UnaryOpKind::Minus && sub; - if (rewriteSub) - index = indexOp->getOperand(1); - - // Handle the cast - auto llvmDstType = mlir::IntegerType::get(ctx, *layoutWidth); - index = getLLVMIntCast(rewriter, index, llvmDstType, - ptrStrideOp.getStride().getType().isUnsigned(), - width, *layoutWidth); - - // Rewrite the sub in front of extensions/trunc - if (rewriteSub) { - index = rewriter.create( - index.getLoc(), index.getType(), - rewriter.create( - index.getLoc(), index.getType(), - mlir::IntegerAttr::get(index.getType(), 0)), - index); - rewriter.eraseOp(sub); - } - } +mlir::LogicalResult CIRToLLVMPtrStrideOpLowering::matchAndRewrite( + cir::PtrStrideOp ptrStrideOp, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto *tc = getTypeConverter(); + const auto resultTy = tc->convertType(ptrStrideOp.getType()); + auto elementTy = tc->convertType(ptrStrideOp.getElementTy()); + auto *ctx = elementTy.getContext(); + + // void and function types doesn't really have a layout to use in GEPs, + // make it i8 instead. + if (mlir::isa(elementTy) || + mlir::isa(elementTy)) + elementTy = mlir::IntegerType::get(elementTy.getContext(), 8, + mlir::IntegerType::Signless); - rewriter.replaceOpWithNewOp( - ptrStrideOp, resultTy, elementTy, adaptor.getBase(), index); - return mlir::success(); + // Zero-extend, sign-extend or trunc the pointer value. + auto index = adaptor.getStride(); + auto width = mlir::cast(index.getType()).getWidth(); + mlir::DataLayout LLVMLayout(ptrStrideOp->getParentOfType()); + auto layoutWidth = + LLVMLayout.getTypeIndexBitwidth(adaptor.getBase().getType()); + auto indexOp = index.getDefiningOp(); + if (indexOp && layoutWidth && width != *layoutWidth) { + // If the index comes from a subtraction, make sure the extension happens + // before it. To achieve that, look at unary minus, which already got + // lowered to "sub 0, x". + auto sub = dyn_cast(indexOp); + auto unary = dyn_cast_if_present( + ptrStrideOp.getStride().getDefiningOp()); + bool rewriteSub = + unary && unary.getKind() == cir::UnaryOpKind::Minus && sub; + if (rewriteSub) + index = indexOp->getOperand(1); + + // Handle the cast + auto llvmDstType = mlir::IntegerType::get(ctx, *layoutWidth); + index = getLLVMIntCast(rewriter, index, llvmDstType, + ptrStrideOp.getStride().getType().isUnsigned(), + width, *layoutWidth); + + // Rewrite the sub in front of extensions/trunc + if (rewriteSub) { + index = rewriter.create( + index.getLoc(), index.getType(), + rewriter.create( + index.getLoc(), index.getType(), + mlir::IntegerAttr::get(index.getType(), 0)), + index); + rewriter.eraseOp(sub); + } } -}; -class CIRBaseClassAddrOpLowering - : public mlir::OpConversionPattern { -public: - using mlir::OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::BaseClassAddrOp baseClassOp, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - const auto resultType = - getTypeConverter()->convertType(baseClassOp.getType()); - mlir::Value derivedAddr = adaptor.getDerivedAddr(); - llvm::SmallVector offset = { - adaptor.getOffset().getZExtValue()}; - mlir::Type byteType = mlir::IntegerType::get(resultType.getContext(), 8, - mlir::IntegerType::Signless); - if (adaptor.getOffset().getZExtValue() == 0) { - rewriter.replaceOpWithNewOp( - baseClassOp, resultType, adaptor.getDerivedAddr()); - return mlir::success(); - } + rewriter.replaceOpWithNewOp( + ptrStrideOp, resultTy, elementTy, adaptor.getBase(), index); + return mlir::success(); +} - if (baseClassOp.getAssumeNotNull()) { - rewriter.replaceOpWithNewOp( - baseClassOp, resultType, byteType, derivedAddr, offset); - } else { - auto loc = baseClassOp.getLoc(); - mlir::Value isNull = rewriter.create( - loc, mlir::LLVM::ICmpPredicate::eq, derivedAddr, - rewriter.create(loc, derivedAddr.getType())); - mlir::Value adjusted = rewriter.create( - loc, resultType, byteType, derivedAddr, offset); - rewriter.replaceOpWithNewOp(baseClassOp, isNull, - derivedAddr, adjusted); - } +mlir::LogicalResult CIRToLLVMBaseClassAddrOpLowering::matchAndRewrite( + cir::BaseClassAddrOp baseClassOp, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + const auto resultType = + getTypeConverter()->convertType(baseClassOp.getType()); + mlir::Value derivedAddr = adaptor.getDerivedAddr(); + llvm::SmallVector offset = { + adaptor.getOffset().getZExtValue()}; + mlir::Type byteType = mlir::IntegerType::get(resultType.getContext(), 8, + mlir::IntegerType::Signless); + if (adaptor.getOffset().getZExtValue() == 0) { + rewriter.replaceOpWithNewOp( + baseClassOp, resultType, adaptor.getDerivedAddr()); return mlir::success(); } -}; -class CIRDerivedClassAddrOpLowering - : public mlir::OpConversionPattern { -public: - using mlir::OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::DerivedClassAddrOp derivedClassOp, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - const auto resultType = - getTypeConverter()->convertType(derivedClassOp.getType()); - mlir::Value baseAddr = adaptor.getBaseAddr(); - int64_t offsetVal = adaptor.getOffset().getZExtValue() * -1; - llvm::SmallVector offset = {offsetVal}; - mlir::Type byteType = mlir::IntegerType::get(resultType.getContext(), 8, - mlir::IntegerType::Signless); - if (derivedClassOp.getAssumeNotNull()) { - rewriter.replaceOpWithNewOp( - derivedClassOp, resultType, byteType, baseAddr, offset); - } else { - auto loc = derivedClassOp.getLoc(); - mlir::Value isNull = rewriter.create( - loc, mlir::LLVM::ICmpPredicate::eq, baseAddr, - rewriter.create(loc, baseAddr.getType())); - mlir::Value adjusted = rewriter.create( - loc, resultType, byteType, baseAddr, offset); - rewriter.replaceOpWithNewOp(derivedClassOp, isNull, - baseAddr, adjusted); - } - return mlir::success(); + if (baseClassOp.getAssumeNotNull()) { + rewriter.replaceOpWithNewOp( + baseClassOp, resultType, byteType, derivedAddr, offset); + } else { + auto loc = baseClassOp.getLoc(); + mlir::Value isNull = rewriter.create( + loc, mlir::LLVM::ICmpPredicate::eq, derivedAddr, + rewriter.create(loc, derivedAddr.getType())); + mlir::Value adjusted = rewriter.create( + loc, resultType, byteType, derivedAddr, offset); + rewriter.replaceOpWithNewOp(baseClassOp, isNull, + derivedAddr, adjusted); } -}; + return mlir::success(); +} + +mlir::LogicalResult CIRToLLVMDerivedClassAddrOpLowering::matchAndRewrite( + cir::DerivedClassAddrOp derivedClassOp, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + const auto resultType = + getTypeConverter()->convertType(derivedClassOp.getType()); + mlir::Value baseAddr = adaptor.getBaseAddr(); + int64_t offsetVal = adaptor.getOffset().getZExtValue() * -1; + llvm::SmallVector offset = {offsetVal}; + mlir::Type byteType = mlir::IntegerType::get(resultType.getContext(), 8, + mlir::IntegerType::Signless); + if (derivedClassOp.getAssumeNotNull()) { + rewriter.replaceOpWithNewOp(derivedClassOp, resultType, + byteType, baseAddr, offset); + } else { + auto loc = derivedClassOp.getLoc(); + mlir::Value isNull = rewriter.create( + loc, mlir::LLVM::ICmpPredicate::eq, baseAddr, + rewriter.create(loc, baseAddr.getType())); + mlir::Value adjusted = rewriter.create( + loc, resultType, byteType, baseAddr, offset); + rewriter.replaceOpWithNewOp(derivedClassOp, isNull, + baseAddr, adjusted); + } + return mlir::success(); +} static mlir::Value getValueForVTableSymbol(mlir::Operation *op, @@ -991,290 +927,265 @@ getValueForVTableSymbol(mlir::Operation *op, nameAttr.getValue()); } -class CIRVTTAddrPointOpLowering - : public mlir::OpConversionPattern { -public: - using mlir::OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::VTTAddrPointOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - const mlir::Type resultType = getTypeConverter()->convertType(op.getType()); - llvm::SmallVector offsets; - mlir::Type eltType; - mlir::Value llvmAddr = adaptor.getSymAddr(); - - if (op.getSymAddr()) { - if (op.getOffset() == 0) { - rewriter.replaceAllUsesWith(op, llvmAddr); - rewriter.eraseOp(op); - return mlir::success(); - } - - offsets.push_back(adaptor.getOffset()); - eltType = mlir::IntegerType::get(resultType.getContext(), 8, - mlir::IntegerType::Signless); - } else { - llvmAddr = getValueForVTableSymbol(op, rewriter, getTypeConverter(), - op.getNameAttr(), eltType); - assert(eltType && "Shouldn't ever be missing an eltType here"); - offsets.push_back(0); - offsets.push_back(adaptor.getOffset()); +mlir::LogicalResult CIRToLLVMVTTAddrPointOpLowering::matchAndRewrite( + cir::VTTAddrPointOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + const mlir::Type resultType = getTypeConverter()->convertType(op.getType()); + llvm::SmallVector offsets; + mlir::Type eltType; + mlir::Value llvmAddr = adaptor.getSymAddr(); + + if (op.getSymAddr()) { + if (op.getOffset() == 0) { + rewriter.replaceAllUsesWith(op, llvmAddr); + rewriter.eraseOp(op); + return mlir::success(); } - rewriter.replaceOpWithNewOp(op, resultType, eltType, - llvmAddr, offsets, true); - return mlir::success(); - } -}; -class CIRBrCondOpLowering : public mlir::OpConversionPattern { -public: - using mlir::OpConversionPattern::OpConversionPattern; + offsets.push_back(adaptor.getOffset()); + eltType = mlir::IntegerType::get(resultType.getContext(), 8, + mlir::IntegerType::Signless); + } else { + llvmAddr = getValueForVTableSymbol(op, rewriter, getTypeConverter(), + op.getNameAttr(), eltType); + assert(eltType && "Shouldn't ever be missing an eltType here"); + offsets.push_back(0); + offsets.push_back(adaptor.getOffset()); + } + rewriter.replaceOpWithNewOp(op, resultType, eltType, + llvmAddr, offsets, true); + return mlir::success(); +} - mlir::LogicalResult - matchAndRewrite(cir::BrCondOp brOp, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - mlir::Value i1Condition; +mlir::LogicalResult CIRToLLVMBrCondOpLowering::matchAndRewrite( + cir::BrCondOp brOp, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + mlir::Value i1Condition; - auto hasOneUse = false; + auto hasOneUse = false; - if (auto defOp = brOp.getCond().getDefiningOp()) - hasOneUse = defOp->getResult(0).hasOneUse(); + if (auto defOp = brOp.getCond().getDefiningOp()) + hasOneUse = defOp->getResult(0).hasOneUse(); - if (auto defOp = adaptor.getCond().getDefiningOp()) { - if (auto zext = dyn_cast(defOp)) { - if (zext->use_empty() && - zext->getOperand(0).getType() == rewriter.getI1Type()) { - i1Condition = zext->getOperand(0); - if (hasOneUse) - rewriter.eraseOp(zext); - } + if (auto defOp = adaptor.getCond().getDefiningOp()) { + if (auto zext = dyn_cast(defOp)) { + if (zext->use_empty() && + zext->getOperand(0).getType() == rewriter.getI1Type()) { + i1Condition = zext->getOperand(0); + if (hasOneUse) + rewriter.eraseOp(zext); } } + } - if (!i1Condition) - i1Condition = rewriter.create( - brOp.getLoc(), rewriter.getI1Type(), adaptor.getCond()); + if (!i1Condition) + i1Condition = rewriter.create( + brOp.getLoc(), rewriter.getI1Type(), adaptor.getCond()); - rewriter.replaceOpWithNewOp( - brOp, i1Condition, brOp.getDestTrue(), adaptor.getDestOperandsTrue(), - brOp.getDestFalse(), adaptor.getDestOperandsFalse()); + rewriter.replaceOpWithNewOp( + brOp, i1Condition, brOp.getDestTrue(), adaptor.getDestOperandsTrue(), + brOp.getDestFalse(), adaptor.getDestOperandsFalse()); - return mlir::success(); - } -}; + return mlir::success(); +} -class CIRCastOpLowering : public mlir::OpConversionPattern { -public: - using mlir::OpConversionPattern::OpConversionPattern; +mlir::Type CIRToLLVMCastOpLowering::convertTy(mlir::Type ty) const { + return getTypeConverter()->convertType(ty); +} - inline mlir::Type convertTy(mlir::Type ty) const { - return getTypeConverter()->convertType(ty); - } +mlir::LogicalResult CIRToLLVMCastOpLowering::matchAndRewrite( + cir::CastOp castOp, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + // For arithmetic conversions, LLVM IR uses the same instruction to convert + // both individual scalars and entire vectors. This lowering pass handles + // both situations. + + auto src = adaptor.getSrc(); + + switch (castOp.getKind()) { + case cir::CastKind::array_to_ptrdecay: { + const auto ptrTy = mlir::cast(castOp.getType()); + auto sourceValue = adaptor.getOperands().front(); + auto targetType = convertTy(ptrTy); + auto elementTy = convertTy(ptrTy.getPointee()); + auto offset = llvm::SmallVector{0}; + rewriter.replaceOpWithNewOp( + castOp, targetType, elementTy, sourceValue, offset); + break; + } + case cir::CastKind::int_to_bool: { + auto zero = rewriter.create( + src.getLoc(), castOp.getSrc().getType(), + cir::IntAttr::get(castOp.getSrc().getType(), 0)); + rewriter.replaceOpWithNewOp( + castOp, cir::BoolType::get(getContext()), cir::CmpOpKind::ne, + castOp.getSrc(), zero); + break; + } + case cir::CastKind::integral: { + auto srcType = castOp.getSrc().getType(); + auto dstType = castOp.getResult().getType(); + auto llvmSrcVal = adaptor.getOperands().front(); + auto llvmDstType = getTypeConverter()->convertType(dstType); + cir::IntType srcIntType = + mlir::cast(elementTypeIfVector(srcType)); + cir::IntType dstIntType = + mlir::cast(elementTypeIfVector(dstType)); + rewriter.replaceOp(castOp, getLLVMIntCast(rewriter, llvmSrcVal, llvmDstType, + srcIntType.isUnsigned(), + srcIntType.getWidth(), + dstIntType.getWidth())); + break; + } + case cir::CastKind::floating: { + auto llvmSrcVal = adaptor.getOperands().front(); + auto llvmDstTy = + getTypeConverter()->convertType(castOp.getResult().getType()); + + auto srcTy = elementTypeIfVector(castOp.getSrc().getType()); + auto dstTy = elementTypeIfVector(castOp.getResult().getType()); + + if (!mlir::isa(dstTy) || + !mlir::isa(srcTy)) + return castOp.emitError() << "NYI cast from " << srcTy << " to " << dstTy; + + auto getFloatWidth = [](mlir::Type ty) -> unsigned { + return mlir::cast(ty).getWidth(); + }; - mlir::LogicalResult - matchAndRewrite(cir::CastOp castOp, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - // For arithmetic conversions, LLVM IR uses the same instruction to convert - // both individual scalars and entire vectors. This lowering pass handles - // both situations. - - auto src = adaptor.getSrc(); - - switch (castOp.getKind()) { - case cir::CastKind::array_to_ptrdecay: { - const auto ptrTy = mlir::cast(castOp.getType()); - auto sourceValue = adaptor.getOperands().front(); - auto targetType = convertTy(ptrTy); - auto elementTy = convertTy(ptrTy.getPointee()); - auto offset = llvm::SmallVector{0}; - rewriter.replaceOpWithNewOp( - castOp, targetType, elementTy, sourceValue, offset); - break; - } - case cir::CastKind::int_to_bool: { - auto zero = rewriter.create( - src.getLoc(), castOp.getSrc().getType(), - cir::IntAttr::get(castOp.getSrc().getType(), 0)); - rewriter.replaceOpWithNewOp( - castOp, cir::BoolType::get(getContext()), cir::CmpOpKind::ne, - castOp.getSrc(), zero); - break; - } - case cir::CastKind::integral: { - auto srcType = castOp.getSrc().getType(); - auto dstType = castOp.getResult().getType(); - auto llvmSrcVal = adaptor.getOperands().front(); - auto llvmDstType = getTypeConverter()->convertType(dstType); - cir::IntType srcIntType = - mlir::cast(elementTypeIfVector(srcType)); - cir::IntType dstIntType = - mlir::cast(elementTypeIfVector(dstType)); - rewriter.replaceOp( - castOp, getLLVMIntCast(rewriter, llvmSrcVal, llvmDstType, - srcIntType.isUnsigned(), srcIntType.getWidth(), - dstIntType.getWidth())); - break; - } - case cir::CastKind::floating: { - auto llvmSrcVal = adaptor.getOperands().front(); - auto llvmDstTy = - getTypeConverter()->convertType(castOp.getResult().getType()); - - auto srcTy = elementTypeIfVector(castOp.getSrc().getType()); - auto dstTy = elementTypeIfVector(castOp.getResult().getType()); - - if (!mlir::isa(dstTy) || - !mlir::isa(srcTy)) - return castOp.emitError() - << "NYI cast from " << srcTy << " to " << dstTy; - - auto getFloatWidth = [](mlir::Type ty) -> unsigned { - return mlir::cast(ty).getWidth(); - }; - - if (getFloatWidth(srcTy) > getFloatWidth(dstTy)) - rewriter.replaceOpWithNewOp(castOp, llvmDstTy, - llvmSrcVal); - else - rewriter.replaceOpWithNewOp(castOp, llvmDstTy, + if (getFloatWidth(srcTy) > getFloatWidth(dstTy)) + rewriter.replaceOpWithNewOp(castOp, llvmDstTy, llvmSrcVal); - return mlir::success(); - } - case cir::CastKind::int_to_ptr: { - auto dstTy = mlir::cast(castOp.getType()); - auto llvmSrcVal = adaptor.getOperands().front(); - auto llvmDstTy = getTypeConverter()->convertType(dstTy); - rewriter.replaceOpWithNewOp(castOp, llvmDstTy, - llvmSrcVal); - return mlir::success(); - } - case cir::CastKind::ptr_to_int: { - auto dstTy = mlir::cast(castOp.getType()); - auto llvmSrcVal = adaptor.getOperands().front(); - auto llvmDstTy = getTypeConverter()->convertType(dstTy); - rewriter.replaceOpWithNewOp(castOp, llvmDstTy, - llvmSrcVal); - return mlir::success(); - } - case cir::CastKind::float_to_bool: { - auto dstTy = mlir::cast(castOp.getType()); - auto llvmSrcVal = adaptor.getOperands().front(); - auto llvmDstTy = getTypeConverter()->convertType(dstTy); - auto kind = mlir::LLVM::FCmpPredicate::une; - - // Check if float is not equal to zero. - auto zeroFloat = rewriter.create( - castOp.getLoc(), llvmSrcVal.getType(), - mlir::FloatAttr::get(llvmSrcVal.getType(), 0.0)); - - // Extend comparison result to either bool (C++) or int (C). - mlir::Value cmpResult = rewriter.create( - castOp.getLoc(), kind, llvmSrcVal, zeroFloat); - rewriter.replaceOpWithNewOp(castOp, llvmDstTy, - cmpResult); - return mlir::success(); - } - case cir::CastKind::bool_to_int: { - auto dstTy = mlir::cast(castOp.getType()); - auto llvmSrcVal = adaptor.getOperands().front(); - auto llvmSrcTy = mlir::cast(llvmSrcVal.getType()); - auto llvmDstTy = - mlir::cast(getTypeConverter()->convertType(dstTy)); - if (llvmSrcTy.getWidth() == llvmDstTy.getWidth()) - rewriter.replaceOpWithNewOp(castOp, llvmDstTy, - llvmSrcVal); - else - rewriter.replaceOpWithNewOp(castOp, llvmDstTy, + else + rewriter.replaceOpWithNewOp(castOp, llvmDstTy, + llvmSrcVal); + return mlir::success(); + } + case cir::CastKind::int_to_ptr: { + auto dstTy = mlir::cast(castOp.getType()); + auto llvmSrcVal = adaptor.getOperands().front(); + auto llvmDstTy = getTypeConverter()->convertType(dstTy); + rewriter.replaceOpWithNewOp(castOp, llvmDstTy, llvmSrcVal); - return mlir::success(); - } - case cir::CastKind::bool_to_float: { - auto dstTy = castOp.getType(); - auto llvmSrcVal = adaptor.getOperands().front(); - auto llvmDstTy = getTypeConverter()->convertType(dstTy); - rewriter.replaceOpWithNewOp(castOp, llvmDstTy, + return mlir::success(); + } + case cir::CastKind::ptr_to_int: { + auto dstTy = mlir::cast(castOp.getType()); + auto llvmSrcVal = adaptor.getOperands().front(); + auto llvmDstTy = getTypeConverter()->convertType(dstTy); + rewriter.replaceOpWithNewOp(castOp, llvmDstTy, llvmSrcVal); - return mlir::success(); - } - case cir::CastKind::int_to_float: { - auto dstTy = castOp.getType(); - auto llvmSrcVal = adaptor.getOperands().front(); - auto llvmDstTy = getTypeConverter()->convertType(dstTy); - if (mlir::cast( - elementTypeIfVector(castOp.getSrc().getType())) - .isSigned()) - rewriter.replaceOpWithNewOp(castOp, llvmDstTy, - llvmSrcVal); - else - rewriter.replaceOpWithNewOp(castOp, llvmDstTy, - llvmSrcVal); - return mlir::success(); - } - case cir::CastKind::float_to_int: { - auto dstTy = castOp.getType(); - auto llvmSrcVal = adaptor.getOperands().front(); - auto llvmDstTy = getTypeConverter()->convertType(dstTy); - if (mlir::cast( - elementTypeIfVector(castOp.getResult().getType())) - .isSigned()) - rewriter.replaceOpWithNewOp(castOp, llvmDstTy, - llvmSrcVal); - else - rewriter.replaceOpWithNewOp(castOp, llvmDstTy, - llvmSrcVal); - return mlir::success(); - } - case cir::CastKind::bitcast: { - auto dstTy = castOp.getType(); - auto llvmSrcVal = adaptor.getOperands().front(); - auto llvmDstTy = getTypeConverter()->convertType(dstTy); + return mlir::success(); + } + case cir::CastKind::float_to_bool: { + auto dstTy = mlir::cast(castOp.getType()); + auto llvmSrcVal = adaptor.getOperands().front(); + auto llvmDstTy = getTypeConverter()->convertType(dstTy); + auto kind = mlir::LLVM::FCmpPredicate::une; + + // Check if float is not equal to zero. + auto zeroFloat = rewriter.create( + castOp.getLoc(), llvmSrcVal.getType(), + mlir::FloatAttr::get(llvmSrcVal.getType(), 0.0)); + + // Extend comparison result to either bool (C++) or int (C). + mlir::Value cmpResult = rewriter.create( + castOp.getLoc(), kind, llvmSrcVal, zeroFloat); + rewriter.replaceOpWithNewOp(castOp, llvmDstTy, + cmpResult); + return mlir::success(); + } + case cir::CastKind::bool_to_int: { + auto dstTy = mlir::cast(castOp.getType()); + auto llvmSrcVal = adaptor.getOperands().front(); + auto llvmSrcTy = mlir::cast(llvmSrcVal.getType()); + auto llvmDstTy = + mlir::cast(getTypeConverter()->convertType(dstTy)); + if (llvmSrcTy.getWidth() == llvmDstTy.getWidth()) rewriter.replaceOpWithNewOp(castOp, llvmDstTy, llvmSrcVal); - return mlir::success(); - } - case cir::CastKind::ptr_to_bool: { - auto zero = - mlir::IntegerAttr::get(mlir::IntegerType::get(getContext(), 64), 0); - auto null = rewriter.create( - src.getLoc(), castOp.getSrc().getType(), - cir::ConstPtrAttr::get(getContext(), castOp.getSrc().getType(), - zero)); - rewriter.replaceOpWithNewOp( - castOp, cir::BoolType::get(getContext()), cir::CmpOpKind::ne, - castOp.getSrc(), null); - break; - } - case cir::CastKind::address_space: { - auto dstTy = castOp.getType(); - auto llvmSrcVal = adaptor.getOperands().front(); - auto llvmDstTy = getTypeConverter()->convertType(dstTy); - rewriter.replaceOpWithNewOp( - castOp, llvmDstTy, llvmSrcVal); - break; - } - default: { - return castOp.emitError("Unhandled cast kind: ") - << castOp.getKindAttrName(); - } - } - + else + rewriter.replaceOpWithNewOp(castOp, llvmDstTy, + llvmSrcVal); return mlir::success(); } -}; + case cir::CastKind::bool_to_float: { + auto dstTy = castOp.getType(); + auto llvmSrcVal = adaptor.getOperands().front(); + auto llvmDstTy = getTypeConverter()->convertType(dstTy); + rewriter.replaceOpWithNewOp(castOp, llvmDstTy, + llvmSrcVal); + return mlir::success(); + } + case cir::CastKind::int_to_float: { + auto dstTy = castOp.getType(); + auto llvmSrcVal = adaptor.getOperands().front(); + auto llvmDstTy = getTypeConverter()->convertType(dstTy); + if (mlir::cast(elementTypeIfVector(castOp.getSrc().getType())) + .isSigned()) + rewriter.replaceOpWithNewOp(castOp, llvmDstTy, + llvmSrcVal); + else + rewriter.replaceOpWithNewOp(castOp, llvmDstTy, + llvmSrcVal); + return mlir::success(); + } + case cir::CastKind::float_to_int: { + auto dstTy = castOp.getType(); + auto llvmSrcVal = adaptor.getOperands().front(); + auto llvmDstTy = getTypeConverter()->convertType(dstTy); + if (mlir::cast( + elementTypeIfVector(castOp.getResult().getType())) + .isSigned()) + rewriter.replaceOpWithNewOp(castOp, llvmDstTy, + llvmSrcVal); + else + rewriter.replaceOpWithNewOp(castOp, llvmDstTy, + llvmSrcVal); + return mlir::success(); + } + case cir::CastKind::bitcast: { + auto dstTy = castOp.getType(); + auto llvmSrcVal = adaptor.getOperands().front(); + auto llvmDstTy = getTypeConverter()->convertType(dstTy); + rewriter.replaceOpWithNewOp(castOp, llvmDstTy, + llvmSrcVal); + return mlir::success(); + } + case cir::CastKind::ptr_to_bool: { + auto zero = + mlir::IntegerAttr::get(mlir::IntegerType::get(getContext(), 64), 0); + auto null = rewriter.create( + src.getLoc(), castOp.getSrc().getType(), + cir::ConstPtrAttr::get(getContext(), castOp.getSrc().getType(), zero)); + rewriter.replaceOpWithNewOp( + castOp, cir::BoolType::get(getContext()), cir::CmpOpKind::ne, + castOp.getSrc(), null); + break; + } + case cir::CastKind::address_space: { + auto dstTy = castOp.getType(); + auto llvmSrcVal = adaptor.getOperands().front(); + auto llvmDstTy = getTypeConverter()->convertType(dstTy); + rewriter.replaceOpWithNewOp(castOp, llvmDstTy, + llvmSrcVal); + break; + } + default: { + return castOp.emitError("Unhandled cast kind: ") + << castOp.getKindAttrName(); + } + } -class CIRReturnLowering : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; + return mlir::success(); +} - mlir::LogicalResult - matchAndRewrite(cir::ReturnOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, - adaptor.getOperands()); - return mlir::LogicalResult::success(); - } -}; +mlir::LogicalResult CIRToLLVMReturnOpLowering::matchAndRewrite( + cir::ReturnOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); + return mlir::LogicalResult::success(); +} struct ConvertCIRToLLVMPass : public mlir::PassWrapper { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::CallOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - return rewriteToCallOrInvoke(op.getOperation(), adaptor.getOperands(), - rewriter, getTypeConverter(), - op.getCalleeAttr()); - } -}; - -class CIRTryCallLowering : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; +mlir::LogicalResult CIRToLLVMCallOpLowering::matchAndRewrite( + cir::CallOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + return rewriteToCallOrInvoke(op.getOperation(), adaptor.getOperands(), + rewriter, getTypeConverter(), + op.getCalleeAttr()); +} - mlir::LogicalResult - matchAndRewrite(cir::TryCallOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - if (op.getCallingConv() != cir::CallingConv::C) { - return op.emitError( - "non-C calling convention is not implemented for try_call"); - } - return rewriteToCallOrInvoke( - op.getOperation(), adaptor.getOperands(), rewriter, getTypeConverter(), - op.getCalleeAttr(), op.getCont(), op.getLandingPad()); +mlir::LogicalResult CIRToLLVMTryCallOpLowering::matchAndRewrite( + cir::TryCallOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + if (op.getCallingConv() != cir::CallingConv::C) { + return op.emitError( + "non-C calling convention is not implemented for try_call"); } -}; + return rewriteToCallOrInvoke(op.getOperation(), adaptor.getOperands(), + rewriter, getTypeConverter(), op.getCalleeAttr(), + op.getCont(), op.getLandingPad()); +} static mlir::LLVM::LLVMStructType getLLVMLandingPadStructTy(mlir::ConversionPatternRewriter &rewriter) { @@ -1385,184 +1286,154 @@ getLLVMLandingPadStructTy(mlir::ConversionPatternRewriter &rewriter) { return mlir::LLVM::LLVMStructType::getLiteral(ctx, structFields); } -class CIREhInflightOpLowering - : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::EhInflightOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - mlir::Location loc = op.getLoc(); - auto llvmLandingPadStructTy = getLLVMLandingPadStructTy(rewriter); - mlir::ArrayAttr symListAttr = op.getSymTypeListAttr(); - mlir::SmallVector symAddrs; - - auto llvmFn = op->getParentOfType(); - assert(llvmFn && "expected LLVM function parent"); - mlir::Block *entryBlock = &llvmFn.getRegion().front(); - assert(entryBlock->isEntryBlock()); - - // %x = landingpad { ptr, i32 } - // Note that since llvm.landingpad has to be the first operation on the - // block, any needed value for its operands has to be added somewhere else. - if (symListAttr) { - // catch ptr @_ZTIi - // catch ptr @_ZTIPKc - for (mlir::Attribute attr : op.getSymTypeListAttr()) { - auto symAttr = cast(attr); - // Generate `llvm.mlir.addressof` for each symbol, and place those - // operations in the LLVM function entry basic block. - mlir::OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(entryBlock); - mlir::Value addrOp = rewriter.create( - loc, mlir::LLVM::LLVMPointerType::get(rewriter.getContext()), - symAttr.getValue()); - symAddrs.push_back(addrOp); - } - } else { - if (!op.getCleanup()) { - // catch ptr null - mlir::OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(entryBlock); - mlir::Value nullOp = rewriter.create( - loc, mlir::LLVM::LLVMPointerType::get(rewriter.getContext())); - symAddrs.push_back(nullOp); - } +mlir::LogicalResult CIRToLLVMEhInflightOpLowering::matchAndRewrite( + cir::EhInflightOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + mlir::Location loc = op.getLoc(); + auto llvmLandingPadStructTy = getLLVMLandingPadStructTy(rewriter); + mlir::ArrayAttr symListAttr = op.getSymTypeListAttr(); + mlir::SmallVector symAddrs; + + auto llvmFn = op->getParentOfType(); + assert(llvmFn && "expected LLVM function parent"); + mlir::Block *entryBlock = &llvmFn.getRegion().front(); + assert(entryBlock->isEntryBlock()); + + // %x = landingpad { ptr, i32 } + // Note that since llvm.landingpad has to be the first operation on the + // block, any needed value for its operands has to be added somewhere else. + if (symListAttr) { + // catch ptr @_ZTIi + // catch ptr @_ZTIPKc + for (mlir::Attribute attr : op.getSymTypeListAttr()) { + auto symAttr = cast(attr); + // Generate `llvm.mlir.addressof` for each symbol, and place those + // operations in the LLVM function entry basic block. + mlir::OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(entryBlock); + mlir::Value addrOp = rewriter.create( + loc, mlir::LLVM::LLVMPointerType::get(rewriter.getContext()), + symAttr.getValue()); + symAddrs.push_back(addrOp); } - - // %slot = extractvalue { ptr, i32 } %x, 0 - // %selector = extractvalue { ptr, i32 } %x, 1 - auto padOp = rewriter.create( - loc, llvmLandingPadStructTy, symAddrs); - SmallVector slotIdx = {0}; - SmallVector selectorIdx = {1}; - - if (op.getCleanup()) - padOp.setCleanup(true); - - mlir::Value slot = - rewriter.create(loc, padOp, slotIdx); - mlir::Value selector = - rewriter.create(loc, padOp, selectorIdx); - - rewriter.replaceOp(op, mlir::ValueRange{slot, selector}); - - // Landing pads are required to be in LLVM functions with personality - // attribute. FIXME: for now hardcode personality creation in order to start - // adding exception tests, once we annotate CIR with such information, - // change it to be in FuncOp lowering instead. - { + } else { + if (!op.getCleanup()) { + // catch ptr null mlir::OpBuilder::InsertionGuard guard(rewriter); - // Insert personality decl before the current function. - rewriter.setInsertionPoint(llvmFn); - auto personalityFnTy = - mlir::LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {}, - /*isVarArg=*/true); - // Get or create `__gxx_personality_v0` - StringRef fnName = "__gxx_personality_v0"; - getOrCreateLLVMFuncOp(rewriter, op, fnName, personalityFnTy); - llvmFn.setPersonality(fnName); + rewriter.setInsertionPointToStart(entryBlock); + mlir::Value nullOp = rewriter.create( + loc, mlir::LLVM::LLVMPointerType::get(rewriter.getContext())); + symAddrs.push_back(nullOp); } - return mlir::success(); } -}; -class CIRAllocaLowering : public mlir::OpConversionPattern { - mlir::DataLayout const &dataLayout; - // Track globals created for annotation related strings - llvm::StringMap &stringGlobalsMap; - // Track globals created for annotation arg related strings. - // They are different from annotation strings, as strings used in args - // are not in llvmMetadataSectionName, and also has aligment 1. - llvm::StringMap &argStringGlobalsMap; - // Track globals created for annotation args. - llvm::MapVector &argsVarMap; + // %slot = extractvalue { ptr, i32 } %x, 0 + // %selector = extractvalue { ptr, i32 } %x, 1 + auto padOp = rewriter.create( + loc, llvmLandingPadStructTy, symAddrs); + SmallVector slotIdx = {0}; + SmallVector selectorIdx = {1}; -public: - CIRAllocaLowering( - mlir::TypeConverter const &typeConverter, - mlir::DataLayout const &dataLayout, - llvm::StringMap &stringGlobalsMap, - llvm::StringMap &argStringGlobalsMap, - llvm::MapVector &argsVarMap, - mlir::MLIRContext *context) - : OpConversionPattern(typeConverter, context), - dataLayout(dataLayout), stringGlobalsMap(stringGlobalsMap), - argStringGlobalsMap(argStringGlobalsMap), argsVarMap(argsVarMap) {} - - void buildAllocaAnnotations(mlir::LLVM::AllocaOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter, - mlir::ArrayAttr annotationValuesArray) const { - mlir::ModuleOp module = op->getParentOfType(); - mlir::OpBuilder globalVarBuilder(module.getContext()); + if (op.getCleanup()) + padOp.setCleanup(true); - mlir::OpBuilder::InsertPoint afterAlloca = rewriter.saveInsertionPoint(); - globalVarBuilder.setInsertionPointToEnd(&module.getBodyRegion().front()); + mlir::Value slot = + rewriter.create(loc, padOp, slotIdx); + mlir::Value selector = + rewriter.create(loc, padOp, selectorIdx); - mlir::Location loc = op.getLoc(); - mlir::OpBuilder varInitBuilder(module.getContext()); - varInitBuilder.restoreInsertionPoint(afterAlloca); + rewriter.replaceOp(op, mlir::ValueRange{slot, selector}); - auto intrinRetTy = mlir::LLVM::LLVMVoidType::get(getContext()); - constexpr const char *intrinNameAttr = "llvm.var.annotation.p0.p0"; - for (mlir::Attribute entry : annotationValuesArray) { - SmallVector intrinsicArgs; - intrinsicArgs.push_back(op.getRes()); - auto annot = cast(entry); - lowerAnnotationValue(loc, loc, annot, module, varInitBuilder, - globalVarBuilder, stringGlobalsMap, - argStringGlobalsMap, argsVarMap, intrinsicArgs); - rewriter.create( - loc, intrinRetTy, mlir::StringAttr::get(getContext(), intrinNameAttr), - intrinsicArgs); - } + // Landing pads are required to be in LLVM functions with personality + // attribute. FIXME: for now hardcode personality creation in order to start + // adding exception tests, once we annotate CIR with such information, + // change it to be in FuncOp lowering instead. + { + mlir::OpBuilder::InsertionGuard guard(rewriter); + // Insert personality decl before the current function. + rewriter.setInsertionPoint(llvmFn); + auto personalityFnTy = + mlir::LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {}, + /*isVarArg=*/true); + // Get or create `__gxx_personality_v0` + StringRef fnName = "__gxx_personality_v0"; + getOrCreateLLVMFuncOp(rewriter, op, fnName, personalityFnTy); + llvmFn.setPersonality(fnName); } + return mlir::success(); +} - mlir::LogicalResult - matchAndRewrite(cir::AllocaOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - mlir::Value size = - op.isDynamic() - ? adaptor.getDynAllocSize() - : rewriter.create( - op.getLoc(), - typeConverter->convertType(rewriter.getIndexType()), - rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); - auto elementTy = getTypeConverter()->convertType(op.getAllocaType()); - auto resultTy = getTypeConverter()->convertType(op.getResult().getType()); - // Verification between the CIR alloca AS and the one from data layout. - { - auto resPtrTy = mlir::cast(resultTy); - auto dlAllocaASAttr = mlir::cast_if_present( - dataLayout.getAllocaMemorySpace()); - // Absence means 0 - // TODO: The query for the alloca AS should be done through CIRDataLayout - // instead to reuse the logic of interpret null attr as 0. - auto dlAllocaAS = dlAllocaASAttr ? dlAllocaASAttr.getInt() : 0; - if (dlAllocaAS != resPtrTy.getAddressSpace()) { - return op.emitError() << "alloca address space doesn't match the one " - "from the target data layout: " - << dlAllocaAS; - } +void CIRToLLVMAllocaOpLowering::buildAllocaAnnotations( + mlir::LLVM::AllocaOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter, + mlir::ArrayAttr annotationValuesArray) const { + mlir::ModuleOp module = op->getParentOfType(); + mlir::OpBuilder globalVarBuilder(module.getContext()); + + mlir::OpBuilder::InsertPoint afterAlloca = rewriter.saveInsertionPoint(); + globalVarBuilder.setInsertionPointToEnd(&module.getBodyRegion().front()); + + mlir::Location loc = op.getLoc(); + mlir::OpBuilder varInitBuilder(module.getContext()); + varInitBuilder.restoreInsertionPoint(afterAlloca); + + auto intrinRetTy = mlir::LLVM::LLVMVoidType::get(getContext()); + constexpr const char *intrinNameAttr = "llvm.var.annotation.p0.p0"; + for (mlir::Attribute entry : annotationValuesArray) { + SmallVector intrinsicArgs; + intrinsicArgs.push_back(op.getRes()); + auto annot = cast(entry); + lowerAnnotationValue(loc, loc, annot, module, varInitBuilder, + globalVarBuilder, stringGlobalsMap, + argStringGlobalsMap, argsVarMap, intrinsicArgs); + rewriter.create( + loc, intrinRetTy, mlir::StringAttr::get(getContext(), intrinNameAttr), + intrinsicArgs); + } +} + +mlir::LogicalResult CIRToLLVMAllocaOpLowering::matchAndRewrite( + cir::AllocaOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + mlir::Value size = + op.isDynamic() ? adaptor.getDynAllocSize() + : rewriter.create( + op.getLoc(), + typeConverter->convertType(rewriter.getIndexType()), + rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); + auto elementTy = getTypeConverter()->convertType(op.getAllocaType()); + auto resultTy = getTypeConverter()->convertType(op.getResult().getType()); + // Verification between the CIR alloca AS and the one from data layout. + { + auto resPtrTy = mlir::cast(resultTy); + auto dlAllocaASAttr = mlir::cast_if_present( + dataLayout.getAllocaMemorySpace()); + // Absence means 0 + // TODO: The query for the alloca AS should be done through CIRDataLayout + // instead to reuse the logic of interpret null attr as 0. + auto dlAllocaAS = dlAllocaASAttr ? dlAllocaASAttr.getInt() : 0; + if (dlAllocaAS != resPtrTy.getAddressSpace()) { + return op.emitError() << "alloca address space doesn't match the one " + "from the target data layout: " + << dlAllocaAS; } + } - // If there are annotations available, copy them out before we destroy the - // original cir.alloca. - mlir::ArrayAttr annotations; - if (op.getAnnotations()) - annotations = op.getAnnotationsAttr(); + // If there are annotations available, copy them out before we destroy the + // original cir.alloca. + mlir::ArrayAttr annotations; + if (op.getAnnotations()) + annotations = op.getAnnotationsAttr(); - auto llvmAlloca = rewriter.replaceOpWithNewOp( - op, resultTy, elementTy, size, op.getAlignmentAttr().getInt()); + auto llvmAlloca = rewriter.replaceOpWithNewOp( + op, resultTy, elementTy, size, op.getAlignmentAttr().getInt()); - if (annotations && !annotations.empty()) - buildAllocaAnnotations(llvmAlloca, adaptor, rewriter, annotations); - return mlir::success(); - } -}; + if (annotations && !annotations.empty()) + buildAllocaAnnotations(llvmAlloca, adaptor, rewriter, annotations); + return mlir::success(); +} -static mlir::LLVM::AtomicOrdering +mlir::LLVM::AtomicOrdering getLLVMMemOrder(std::optional &memorder) { if (!memorder) return mlir::LLVM::AtomicOrdering::not_atomic; @@ -1582,63 +1453,52 @@ getLLVMMemOrder(std::optional &memorder) { llvm_unreachable("unknown memory order"); } -class CIRLoadLowering : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::LoadOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - const auto llvmTy = - getTypeConverter()->convertType(op.getResult().getType()); - auto memorder = op.getMemOrder(); - auto ordering = getLLVMMemOrder(memorder); - auto alignOpt = op.getAlignment(); - unsigned alignment = 0; - if (!alignOpt) { - mlir::DataLayout layout(op->getParentOfType()); - alignment = (unsigned)layout.getTypeABIAlignment(llvmTy); - } else { - alignment = *alignOpt; - } - - // TODO: nontemporal, invariant, syncscope. - rewriter.replaceOpWithNewOp( - op, llvmTy, adaptor.getAddr(), /* alignment */ alignment, - op.getIsVolatile(), /* nontemporal */ false, - /* invariant */ false, ordering); - return mlir::LogicalResult::success(); +mlir::LogicalResult CIRToLLVMLoadOpLowering::matchAndRewrite( + cir::LoadOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + const auto llvmTy = getTypeConverter()->convertType(op.getResult().getType()); + auto memorder = op.getMemOrder(); + auto ordering = getLLVMMemOrder(memorder); + auto alignOpt = op.getAlignment(); + unsigned alignment = 0; + if (!alignOpt) { + mlir::DataLayout layout(op->getParentOfType()); + alignment = (unsigned)layout.getTypeABIAlignment(llvmTy); + } else { + alignment = *alignOpt; } -}; -class CIRStoreLowering : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::StoreOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - auto memorder = op.getMemOrder(); - auto ordering = getLLVMMemOrder(memorder); - auto alignOpt = op.getAlignment(); - unsigned alignment = 0; - - if (!alignOpt) { - const auto llvmTy = - getTypeConverter()->convertType(op.getValue().getType()); - mlir::DataLayout layout(op->getParentOfType()); - alignment = (unsigned)layout.getTypeABIAlignment(llvmTy); - } else { - alignment = *alignOpt; - } + // TODO: nontemporal, invariant, syncscope. + rewriter.replaceOpWithNewOp( + op, llvmTy, adaptor.getAddr(), /* alignment */ alignment, + op.getIsVolatile(), /* nontemporal */ false, + /* invariant */ false, ordering); + return mlir::LogicalResult::success(); +} - // TODO: nontemporal, syncscope. - rewriter.replaceOpWithNewOp( - op, adaptor.getValue(), adaptor.getAddr(), alignment, - op.getIsVolatile(), /* nontemporal */ false, ordering); - return mlir::LogicalResult::success(); +mlir::LogicalResult CIRToLLVMStoreOpLowering::matchAndRewrite( + cir::StoreOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto memorder = op.getMemOrder(); + auto ordering = getLLVMMemOrder(memorder); + auto alignOpt = op.getAlignment(); + unsigned alignment = 0; + + if (!alignOpt) { + const auto llvmTy = + getTypeConverter()->convertType(op.getValue().getType()); + mlir::DataLayout layout(op->getParentOfType()); + alignment = (unsigned)layout.getTypeABIAlignment(llvmTy); + } else { + alignment = *alignOpt; } -}; + + // TODO: nontemporal, syncscope. + rewriter.replaceOpWithNewOp( + op, adaptor.getValue(), adaptor.getAddr(), alignment, op.getIsVolatile(), + /* nontemporal */ false, ordering); + return mlir::LogicalResult::success(); +} bool hasTrailingZeros(cir::ConstArrayAttr attr) { auto array = mlir::dyn_cast(attr.getElts()); @@ -1671,1313 +1531,1152 @@ lowerDataMemberAttr(mlir::ModuleOp moduleOp, cir::DataMemberAttr attr, return mlir::IntegerAttr::get(underlyingIntTy, memberOffset); } -class CIRConstantLowering : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::ConstantOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - mlir::Attribute attr = op.getValue(); - - if (mlir::isa(op.getType())) { - int value = (op.getValue() == - cir::BoolAttr::get(getContext(), - cir::BoolType::get(getContext()), true)); - attr = rewriter.getIntegerAttr(typeConverter->convertType(op.getType()), - value); - } else if (mlir::isa(op.getType())) { - attr = rewriter.getIntegerAttr( - typeConverter->convertType(op.getType()), - mlir::cast(op.getValue()).getValue()); - } else if (mlir::isa(op.getType())) { - attr = rewriter.getFloatAttr( - typeConverter->convertType(op.getType()), - mlir::cast(op.getValue()).getValue()); - } else if (auto complexTy = - mlir::dyn_cast(op.getType())) { - auto complexAttr = mlir::cast(op.getValue()); - auto complexElemTy = complexTy.getElementTy(); - auto complexElemLLVMTy = typeConverter->convertType(complexElemTy); - - mlir::Attribute components[2]; - if (mlir::isa(complexElemTy)) { - components[0] = rewriter.getIntegerAttr( - complexElemLLVMTy, - mlir::cast(complexAttr.getReal()).getValue()); - components[1] = rewriter.getIntegerAttr( - complexElemLLVMTy, - mlir::cast(complexAttr.getImag()).getValue()); - } else { - components[0] = rewriter.getFloatAttr( - complexElemLLVMTy, - mlir::cast(complexAttr.getReal()).getValue()); - components[1] = rewriter.getFloatAttr( - complexElemLLVMTy, - mlir::cast(complexAttr.getImag()).getValue()); - } +mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite( + cir::ConstantOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + mlir::Attribute attr = op.getValue(); + + if (mlir::isa(op.getType())) { + int value = (op.getValue() == + cir::BoolAttr::get(getContext(), + cir::BoolType::get(getContext()), true)); + attr = rewriter.getIntegerAttr(typeConverter->convertType(op.getType()), + value); + } else if (mlir::isa(op.getType())) { + attr = rewriter.getIntegerAttr( + typeConverter->convertType(op.getType()), + mlir::cast(op.getValue()).getValue()); + } else if (mlir::isa(op.getType())) { + attr = rewriter.getFloatAttr( + typeConverter->convertType(op.getType()), + mlir::cast(op.getValue()).getValue()); + } else if (auto complexTy = mlir::dyn_cast(op.getType())) { + auto complexAttr = mlir::cast(op.getValue()); + auto complexElemTy = complexTy.getElementTy(); + auto complexElemLLVMTy = typeConverter->convertType(complexElemTy); + + mlir::Attribute components[2]; + if (mlir::isa(complexElemTy)) { + components[0] = rewriter.getIntegerAttr( + complexElemLLVMTy, + mlir::cast(complexAttr.getReal()).getValue()); + components[1] = rewriter.getIntegerAttr( + complexElemLLVMTy, + mlir::cast(complexAttr.getImag()).getValue()); + } else { + components[0] = rewriter.getFloatAttr( + complexElemLLVMTy, + mlir::cast(complexAttr.getReal()).getValue()); + components[1] = rewriter.getFloatAttr( + complexElemLLVMTy, + mlir::cast(complexAttr.getImag()).getValue()); + } - attr = rewriter.getArrayAttr(components); - } else if (mlir::isa(op.getType())) { - // Optimize with dedicated LLVM op for null pointers. - if (mlir::isa(op.getValue())) { - if (mlir::cast(op.getValue()).isNullValue()) { - rewriter.replaceOpWithNewOp( - op, typeConverter->convertType(op.getType())); - return mlir::success(); - } - } - // Lower GlobalViewAttr to llvm.mlir.addressof - if (auto gv = mlir::dyn_cast(op.getValue())) { - auto newOp = lowerCirAttrAsValue(op, gv, rewriter, getTypeConverter()); - rewriter.replaceOp(op, newOp); + attr = rewriter.getArrayAttr(components); + } else if (mlir::isa(op.getType())) { + // Optimize with dedicated LLVM op for null pointers. + if (mlir::isa(op.getValue())) { + if (mlir::cast(op.getValue()).isNullValue()) { + rewriter.replaceOpWithNewOp( + op, typeConverter->convertType(op.getType())); return mlir::success(); } - attr = op.getValue(); - } else if (mlir::isa(op.getType())) { - auto dataMember = mlir::cast(op.getValue()); - attr = lowerDataMemberAttr(op->getParentOfType(), - dataMember, *typeConverter); } - // TODO(cir): constant arrays are currently just pushed into the stack using - // the store instruction, instead of being stored as global variables and - // then memcopyied into the stack (as done in Clang). - else if (auto arrTy = mlir::dyn_cast(op.getType())) { - // Fetch operation constant array initializer. - - auto constArr = mlir::dyn_cast(op.getValue()); - if (!constArr && !isa(op.getValue())) - return op.emitError() << "array does not have a constant initializer"; - - std::optional denseAttr; - if (constArr && hasTrailingZeros(constArr)) { - auto newOp = - lowerCirAttrAsValue(op, constArr, rewriter, getTypeConverter()); - rewriter.replaceOp(op, newOp); - return mlir::success(); - } else if (constArr && - (denseAttr = lowerConstArrayAttr(constArr, typeConverter))) { - attr = denseAttr.value(); - } else { - auto initVal = - lowerCirAttrAsValue(op, op.getValue(), rewriter, typeConverter); - rewriter.replaceAllUsesWith(op, initVal); - rewriter.eraseOp(op); - return mlir::success(); - } - } else if (const auto structAttr = - mlir::dyn_cast(op.getValue())) { - // TODO(cir): this diverges from traditional lowering. Normally the - // initializer would be a global constant that is memcopied. Here we just - // define a local constant with llvm.undef that will be stored into the - // stack. + // Lower GlobalViewAttr to llvm.mlir.addressof + if (auto gv = mlir::dyn_cast(op.getValue())) { + auto newOp = lowerCirAttrAsValue(op, gv, rewriter, getTypeConverter()); + rewriter.replaceOp(op, newOp); + return mlir::success(); + } + attr = op.getValue(); + } else if (mlir::isa(op.getType())) { + auto dataMember = mlir::cast(op.getValue()); + attr = lowerDataMemberAttr(op->getParentOfType(), + dataMember, *typeConverter); + } + // TODO(cir): constant arrays are currently just pushed into the stack using + // the store instruction, instead of being stored as global variables and + // then memcopyied into the stack (as done in Clang). + else if (auto arrTy = mlir::dyn_cast(op.getType())) { + // Fetch operation constant array initializer. + + auto constArr = mlir::dyn_cast(op.getValue()); + if (!constArr && !isa(op.getValue())) + return op.emitError() << "array does not have a constant initializer"; + + std::optional denseAttr; + if (constArr && hasTrailingZeros(constArr)) { + auto newOp = + lowerCirAttrAsValue(op, constArr, rewriter, getTypeConverter()); + rewriter.replaceOp(op, newOp); + return mlir::success(); + } else if (constArr && + (denseAttr = lowerConstArrayAttr(constArr, typeConverter))) { + attr = denseAttr.value(); + } else { auto initVal = - lowerCirAttrAsValue(op, structAttr, rewriter, typeConverter); + lowerCirAttrAsValue(op, op.getValue(), rewriter, typeConverter); rewriter.replaceAllUsesWith(op, initVal); rewriter.eraseOp(op); return mlir::success(); - } else if (auto strTy = mlir::dyn_cast(op.getType())) { - auto attr = op.getValue(); - if (mlir::isa(attr)) { - auto initVal = lowerCirAttrAsValue(op, attr, rewriter, typeConverter); - rewriter.replaceAllUsesWith(op, initVal); - rewriter.eraseOp(op); - return mlir::success(); - } - - return op.emitError() << "unsupported lowering for struct constant type " - << op.getType(); - } else if (const auto vecTy = - mlir::dyn_cast(op.getType())) { - rewriter.replaceOp(op, lowerCirAttrAsValue(op, op.getValue(), rewriter, - getTypeConverter())); + } + } else if (const auto structAttr = + mlir::dyn_cast(op.getValue())) { + // TODO(cir): this diverges from traditional lowering. Normally the + // initializer would be a global constant that is memcopied. Here we just + // define a local constant with llvm.undef that will be stored into the + // stack. + auto initVal = lowerCirAttrAsValue(op, structAttr, rewriter, typeConverter); + rewriter.replaceAllUsesWith(op, initVal); + rewriter.eraseOp(op); + return mlir::success(); + } else if (auto strTy = mlir::dyn_cast(op.getType())) { + auto attr = op.getValue(); + if (mlir::isa(attr)) { + auto initVal = lowerCirAttrAsValue(op, attr, rewriter, typeConverter); + rewriter.replaceAllUsesWith(op, initVal); + rewriter.eraseOp(op); return mlir::success(); - } else - return op.emitError() << "unsupported constant type " << op.getType(); - - rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), attr); + } + return op.emitError() << "unsupported lowering for struct constant type " + << op.getType(); + } else if (const auto vecTy = mlir::dyn_cast(op.getType())) { + rewriter.replaceOp(op, lowerCirAttrAsValue(op, op.getValue(), rewriter, + getTypeConverter())); return mlir::success(); - } -}; + } else + return op.emitError() << "unsupported constant type " << op.getType(); -class CIRVectorCreateLowering - : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::VecCreateOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - // Start with an 'undef' value for the vector. Then 'insertelement' for - // each of the vector elements. - auto vecTy = mlir::dyn_cast(op.getType()); - assert(vecTy && "result type of cir.vec.create op is not VectorType"); - auto llvmTy = typeConverter->convertType(vecTy); - auto loc = op.getLoc(); - mlir::Value result = rewriter.create(loc, llvmTy); - assert(vecTy.getSize() == op.getElements().size() && - "cir.vec.create op count doesn't match vector type elements count"); - for (uint64_t i = 0; i < vecTy.getSize(); ++i) { - mlir::Value indexValue = rewriter.create( - loc, rewriter.getI64Type(), i); - result = rewriter.create( - loc, result, adaptor.getElements()[i], indexValue); - } - rewriter.replaceOp(op, result); - return mlir::success(); - } -}; + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), attr); -class CIRVectorCmpOpLowering : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::VecCmpOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - assert(mlir::isa(op.getType()) && - mlir::isa(op.getLhs().getType()) && - mlir::isa(op.getRhs().getType()) && - "Vector compare with non-vector type"); - // LLVM IR vector comparison returns a vector of i1. This one-bit vector - // must be sign-extended to the correct result type. - auto elementType = elementTypeIfVector(op.getLhs().getType()); - mlir::Value bitResult; - if (auto intType = mlir::dyn_cast(elementType)) { - bitResult = rewriter.create( - op.getLoc(), - convertCmpKindToICmpPredicate(op.getKind(), intType.isSigned()), - adaptor.getLhs(), adaptor.getRhs()); - } else if (mlir::isa(elementType)) { - bitResult = rewriter.create( - op.getLoc(), convertCmpKindToFCmpPredicate(op.getKind()), - adaptor.getLhs(), adaptor.getRhs()); - } else { - return op.emitError() << "unsupported type for VecCmpOp: " << elementType; - } - rewriter.replaceOpWithNewOp( - op, typeConverter->convertType(op.getType()), bitResult); - return mlir::success(); - } -}; + return mlir::success(); +} -class CIRVectorSplatLowering - : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::VecSplatOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - // Vector splat can be implemented with an `insertelement` and a - // `shufflevector`, which is better than an `insertelement` for each - // element in the vector. Start with an undef vector. Insert the value into - // the first element. Then use a `shufflevector` with a mask of all 0 to - // fill out the entire vector with that value. - auto vecTy = mlir::dyn_cast(op.getType()); - assert(vecTy && "result type of cir.vec.splat op is not VectorType"); - auto llvmTy = typeConverter->convertType(vecTy); - auto loc = op.getLoc(); - mlir::Value undef = rewriter.create(loc, llvmTy); +mlir::LogicalResult CIRToLLVMVecCreateOpLowering::matchAndRewrite( + cir::VecCreateOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + // Start with an 'undef' value for the vector. Then 'insertelement' for + // each of the vector elements. + auto vecTy = mlir::dyn_cast(op.getType()); + assert(vecTy && "result type of cir.vec.create op is not VectorType"); + auto llvmTy = typeConverter->convertType(vecTy); + auto loc = op.getLoc(); + mlir::Value result = rewriter.create(loc, llvmTy); + assert(vecTy.getSize() == op.getElements().size() && + "cir.vec.create op count doesn't match vector type elements count"); + for (uint64_t i = 0; i < vecTy.getSize(); ++i) { mlir::Value indexValue = - rewriter.create(loc, rewriter.getI64Type(), 0); - mlir::Value elementValue = adaptor.getValue(); - mlir::Value oneElement = rewriter.create( - loc, undef, elementValue, indexValue); - SmallVector zeroValues(vecTy.getSize(), 0); - mlir::Value shuffled = rewriter.create( - loc, oneElement, undef, zeroValues); - rewriter.replaceOp(op, shuffled); - return mlir::success(); + rewriter.create(loc, rewriter.getI64Type(), i); + result = rewriter.create( + loc, result, adaptor.getElements()[i], indexValue); } -}; + rewriter.replaceOp(op, result); + return mlir::success(); +} -class CIRVectorTernaryLowering - : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::VecTernaryOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - assert(mlir::isa(op.getType()) && - mlir::isa(op.getCond().getType()) && - mlir::isa(op.getVec1().getType()) && - mlir::isa(op.getVec2().getType()) && - "Vector ternary op with non-vector type"); - // Convert `cond` into a vector of i1, then use that in a `select` op. - mlir::Value bitVec = rewriter.create( - op.getLoc(), mlir::LLVM::ICmpPredicate::ne, adaptor.getCond(), - rewriter.create( - op.getCond().getLoc(), - typeConverter->convertType(op.getCond().getType()))); - rewriter.replaceOpWithNewOp( - op, bitVec, adaptor.getVec1(), adaptor.getVec2()); - return mlir::success(); +mlir::LogicalResult CIRToLLVMVecCmpOpLowering::matchAndRewrite( + cir::VecCmpOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + assert(mlir::isa(op.getType()) && + mlir::isa(op.getLhs().getType()) && + mlir::isa(op.getRhs().getType()) && + "Vector compare with non-vector type"); + // LLVM IR vector comparison returns a vector of i1. This one-bit vector + // must be sign-extended to the correct result type. + auto elementType = elementTypeIfVector(op.getLhs().getType()); + mlir::Value bitResult; + if (auto intType = mlir::dyn_cast(elementType)) { + bitResult = rewriter.create( + op.getLoc(), + convertCmpKindToICmpPredicate(op.getKind(), intType.isSigned()), + adaptor.getLhs(), adaptor.getRhs()); + } else if (mlir::isa(elementType)) { + bitResult = rewriter.create( + op.getLoc(), convertCmpKindToFCmpPredicate(op.getKind()), + adaptor.getLhs(), adaptor.getRhs()); + } else { + return op.emitError() << "unsupported type for VecCmpOp: " << elementType; } -}; + rewriter.replaceOpWithNewOp( + op, typeConverter->convertType(op.getType()), bitResult); + return mlir::success(); +} -class CIRVectorShuffleIntsLowering - : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::VecShuffleOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - // LLVM::ShuffleVectorOp takes an ArrayRef of int for the list of indices. - // Convert the ClangIR ArrayAttr of IntAttr constants into a - // SmallVector. - SmallVector indices; - std::transform( - op.getIndices().begin(), op.getIndices().end(), - std::back_inserter(indices), [](mlir::Attribute intAttr) { - return mlir::cast(intAttr).getValue().getSExtValue(); - }); - rewriter.replaceOpWithNewOp( - op, adaptor.getVec1(), adaptor.getVec2(), indices); - return mlir::success(); - } -}; +mlir::LogicalResult CIRToLLVMVecSplatOpLowering::matchAndRewrite( + cir::VecSplatOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + // Vector splat can be implemented with an `insertelement` and a + // `shufflevector`, which is better than an `insertelement` for each + // element in the vector. Start with an undef vector. Insert the value into + // the first element. Then use a `shufflevector` with a mask of all 0 to + // fill out the entire vector with that value. + auto vecTy = mlir::dyn_cast(op.getType()); + assert(vecTy && "result type of cir.vec.splat op is not VectorType"); + auto llvmTy = typeConverter->convertType(vecTy); + auto loc = op.getLoc(); + mlir::Value undef = rewriter.create(loc, llvmTy); + mlir::Value indexValue = + rewriter.create(loc, rewriter.getI64Type(), 0); + mlir::Value elementValue = adaptor.getValue(); + mlir::Value oneElement = rewriter.create( + loc, undef, elementValue, indexValue); + SmallVector zeroValues(vecTy.getSize(), 0); + mlir::Value shuffled = rewriter.create( + loc, oneElement, undef, zeroValues); + rewriter.replaceOp(op, shuffled); + return mlir::success(); +} -class CIRVectorShuffleVecLowering - : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::VecShuffleDynamicOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - // LLVM IR does not have an operation that corresponds to this form of - // the built-in. - // __builtin_shufflevector(V, I) - // is implemented as this pseudocode, where the for loop is unrolled - // and N is the number of elements: - // masked = I & (N-1) - // for (i in 0 <= i < N) - // result[i] = V[masked[i]] - auto loc = op.getLoc(); - mlir::Value input = adaptor.getVec(); - mlir::Type llvmIndexVecType = - getTypeConverter()->convertType(op.getIndices().getType()); - mlir::Type llvmIndexType = getTypeConverter()->convertType( - elementTypeIfVector(op.getIndices().getType())); - uint64_t numElements = - mlir::cast(op.getVec().getType()).getSize(); - mlir::Value maskValue = rewriter.create( - loc, llvmIndexType, - mlir::IntegerAttr::get(llvmIndexType, numElements - 1)); - mlir::Value maskVector = - rewriter.create(loc, llvmIndexVecType); - for (uint64_t i = 0; i < numElements; ++i) { - mlir::Value iValue = rewriter.create( - loc, rewriter.getI64Type(), i); - maskVector = rewriter.create( - loc, maskVector, maskValue, iValue); - } - mlir::Value maskedIndices = rewriter.create( - loc, llvmIndexVecType, adaptor.getIndices(), maskVector); - mlir::Value result = rewriter.create( - loc, getTypeConverter()->convertType(op.getVec().getType())); - for (uint64_t i = 0; i < numElements; ++i) { - mlir::Value iValue = rewriter.create( - loc, rewriter.getI64Type(), i); - mlir::Value indexValue = rewriter.create( - loc, maskedIndices, iValue); - mlir::Value valueAtIndex = - rewriter.create(loc, input, indexValue); - result = rewriter.create( - loc, result, valueAtIndex, iValue); - } - rewriter.replaceOp(op, result); - return mlir::success(); - } -}; +mlir::LogicalResult CIRToLLVMVecTernaryOpLowering::matchAndRewrite( + cir::VecTernaryOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + assert(mlir::isa(op.getType()) && + mlir::isa(op.getCond().getType()) && + mlir::isa(op.getVec1().getType()) && + mlir::isa(op.getVec2().getType()) && + "Vector ternary op with non-vector type"); + // Convert `cond` into a vector of i1, then use that in a `select` op. + mlir::Value bitVec = rewriter.create( + op.getLoc(), mlir::LLVM::ICmpPredicate::ne, adaptor.getCond(), + rewriter.create( + op.getCond().getLoc(), + typeConverter->convertType(op.getCond().getType()))); + rewriter.replaceOpWithNewOp( + op, bitVec, adaptor.getVec1(), adaptor.getVec2()); + return mlir::success(); +} -class CIRVAStartLowering : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::VAStartOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - auto opaquePtr = mlir::LLVM::LLVMPointerType::get(getContext()); - auto vaList = rewriter.create( - op.getLoc(), opaquePtr, adaptor.getOperands().front()); - rewriter.replaceOpWithNewOp(op, vaList); - return mlir::success(); - } -}; +mlir::LogicalResult CIRToLLVMVecShuffleOpLowering::matchAndRewrite( + cir::VecShuffleOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + // LLVM::ShuffleVectorOp takes an ArrayRef of int for the list of indices. + // Convert the ClangIR ArrayAttr of IntAttr constants into a + // SmallVector. + SmallVector indices; + std::transform( + op.getIndices().begin(), op.getIndices().end(), + std::back_inserter(indices), [](mlir::Attribute intAttr) { + return mlir::cast(intAttr).getValue().getSExtValue(); + }); + rewriter.replaceOpWithNewOp( + op, adaptor.getVec1(), adaptor.getVec2(), indices); + return mlir::success(); +} -class CIRVAEndLowering : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::VAEndOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - auto opaquePtr = mlir::LLVM::LLVMPointerType::get(getContext()); - auto vaList = rewriter.create( - op.getLoc(), opaquePtr, adaptor.getOperands().front()); - rewriter.replaceOpWithNewOp(op, vaList); - return mlir::success(); - } -}; +mlir::LogicalResult CIRToLLVMVecShuffleDynamicOpLowering::matchAndRewrite( + cir::VecShuffleDynamicOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + // LLVM IR does not have an operation that corresponds to this form of + // the built-in. + // __builtin_shufflevector(V, I) + // is implemented as this pseudocode, where the for loop is unrolled + // and N is the number of elements: + // masked = I & (N-1) + // for (i in 0 <= i < N) + // result[i] = V[masked[i]] + auto loc = op.getLoc(); + mlir::Value input = adaptor.getVec(); + mlir::Type llvmIndexVecType = + getTypeConverter()->convertType(op.getIndices().getType()); + mlir::Type llvmIndexType = getTypeConverter()->convertType( + elementTypeIfVector(op.getIndices().getType())); + uint64_t numElements = + mlir::cast(op.getVec().getType()).getSize(); + mlir::Value maskValue = rewriter.create( + loc, llvmIndexType, + mlir::IntegerAttr::get(llvmIndexType, numElements - 1)); + mlir::Value maskVector = + rewriter.create(loc, llvmIndexVecType); + for (uint64_t i = 0; i < numElements; ++i) { + mlir::Value iValue = + rewriter.create(loc, rewriter.getI64Type(), i); + maskVector = rewriter.create( + loc, maskVector, maskValue, iValue); + } + mlir::Value maskedIndices = rewriter.create( + loc, llvmIndexVecType, adaptor.getIndices(), maskVector); + mlir::Value result = rewriter.create( + loc, getTypeConverter()->convertType(op.getVec().getType())); + for (uint64_t i = 0; i < numElements; ++i) { + mlir::Value iValue = + rewriter.create(loc, rewriter.getI64Type(), i); + mlir::Value indexValue = rewriter.create( + loc, maskedIndices, iValue); + mlir::Value valueAtIndex = + rewriter.create(loc, input, indexValue); + result = rewriter.create(loc, result, + valueAtIndex, iValue); + } + rewriter.replaceOp(op, result); + return mlir::success(); +} -class CIRVACopyLowering : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::VACopyOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - auto opaquePtr = mlir::LLVM::LLVMPointerType::get(getContext()); - auto dstList = rewriter.create( - op.getLoc(), opaquePtr, adaptor.getOperands().front()); - auto srcList = rewriter.create( - op.getLoc(), opaquePtr, adaptor.getOperands().back()); - rewriter.replaceOpWithNewOp(op, dstList, srcList); - return mlir::success(); - } -}; +mlir::LogicalResult CIRToLLVMVAStartOpLowering::matchAndRewrite( + cir::VAStartOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto opaquePtr = mlir::LLVM::LLVMPointerType::get(getContext()); + auto vaList = rewriter.create( + op.getLoc(), opaquePtr, adaptor.getOperands().front()); + rewriter.replaceOpWithNewOp(op, vaList); + return mlir::success(); +} -class CIRVAArgLowering : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; +mlir::LogicalResult CIRToLLVMVAEndOpLowering::matchAndRewrite( + cir::VAEndOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto opaquePtr = mlir::LLVM::LLVMPointerType::get(getContext()); + auto vaList = rewriter.create( + op.getLoc(), opaquePtr, adaptor.getOperands().front()); + rewriter.replaceOpWithNewOp(op, vaList); + return mlir::success(); +} - mlir::LogicalResult - matchAndRewrite(cir::VAArgOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - return op.emitError("cir.vaarg lowering is NYI"); - } -}; +mlir::LogicalResult CIRToLLVMVACopyOpLowering::matchAndRewrite( + cir::VACopyOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto opaquePtr = mlir::LLVM::LLVMPointerType::get(getContext()); + auto dstList = rewriter.create( + op.getLoc(), opaquePtr, adaptor.getOperands().front()); + auto srcList = rewriter.create( + op.getLoc(), opaquePtr, adaptor.getOperands().back()); + rewriter.replaceOpWithNewOp(op, dstList, srcList); + return mlir::success(); +} -class CIRFuncLowering : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; +mlir::LogicalResult CIRToLLVMVAArgOpLowering::matchAndRewrite( + cir::VAArgOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + return op.emitError("cir.vaarg lowering is NYI"); +} /// Returns the name used for the linkage attribute. This *must* correspond /// to the name of the attribute in ODS. - static StringRef getLinkageAttrNameString() { return "linkage"; } - - /// Convert the `cir.func` attributes to `llvm.func` attributes. - /// Only retain those attributes that are not constructed by - /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out - /// argument attributes. - void - lowerFuncAttributes(cir::FuncOp func, bool filterArgAndResAttrs, - SmallVectorImpl &result) const { - for (auto attr : func->getAttrs()) { - if (attr.getName() == mlir::SymbolTable::getSymbolAttrName() || - attr.getName() == func.getFunctionTypeAttrName() || - attr.getName() == getLinkageAttrNameString() || - attr.getName() == func.getCallingConvAttrName() || - (filterArgAndResAttrs && - (attr.getName() == func.getArgAttrsAttrName() || - attr.getName() == func.getResAttrsAttrName()))) - continue; - - // `CIRDialectLLVMIRTranslationInterface` requires "cir." prefix for - // dialect specific attributes, rename them. - if (attr.getName() == func.getExtraAttrsAttrName()) { - std::string cirName = "cir." + func.getExtraAttrsAttrName().str(); - attr.setName(mlir::StringAttr::get(getContext(), cirName)); +StringRef CIRToLLVMFuncOpLowering::getLinkageAttrNameString() { + return "linkage"; +} - lowerFuncOpenCLKernelMetadata(attr); - } - result.push_back(attr); +/// Convert the `cir.func` attributes to `llvm.func` attributes. +/// Only retain those attributes that are not constructed by +/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out +/// argument attributes. +void CIRToLLVMFuncOpLowering::lowerFuncAttributes( + cir::FuncOp func, bool filterArgAndResAttrs, + SmallVectorImpl &result) const { + for (auto attr : func->getAttrs()) { + if (attr.getName() == mlir::SymbolTable::getSymbolAttrName() || + attr.getName() == func.getFunctionTypeAttrName() || + attr.getName() == getLinkageAttrNameString() || + attr.getName() == func.getCallingConvAttrName() || + (filterArgAndResAttrs && + (attr.getName() == func.getArgAttrsAttrName() || + attr.getName() == func.getResAttrsAttrName()))) + continue; + + // `CIRDialectLLVMIRTranslationInterface` requires "cir." prefix for + // dialect specific attributes, rename them. + if (attr.getName() == func.getExtraAttrsAttrName()) { + std::string cirName = "cir." + func.getExtraAttrsAttrName().str(); + attr.setName(mlir::StringAttr::get(getContext(), cirName)); + + lowerFuncOpenCLKernelMetadata(attr); } + result.push_back(attr); } +} /// When do module translation, we can only translate LLVM-compatible types. /// Here we lower possible OpenCLKernelMetadataAttr to use the converted type. - void - lowerFuncOpenCLKernelMetadata(mlir::NamedAttribute &extraAttrsEntry) const { - const auto attrKey = cir::OpenCLKernelMetadataAttr::getMnemonic(); - auto oldExtraAttrs = - cast(extraAttrsEntry.getValue()); - if (!oldExtraAttrs.getElements().contains(attrKey)) - return; +void CIRToLLVMFuncOpLowering::lowerFuncOpenCLKernelMetadata( + mlir::NamedAttribute &extraAttrsEntry) const { + const auto attrKey = cir::OpenCLKernelMetadataAttr::getMnemonic(); + auto oldExtraAttrs = + cast(extraAttrsEntry.getValue()); + if (!oldExtraAttrs.getElements().contains(attrKey)) + return; - mlir::NamedAttrList newExtraAttrs; - for (auto entry : oldExtraAttrs.getElements()) { - if (entry.getName() == attrKey) { - auto clKernelMetadata = - cast(entry.getValue()); - if (auto vecTypeHint = clKernelMetadata.getVecTypeHint()) { - auto newType = typeConverter->convertType(vecTypeHint.getValue()); - auto newTypeHint = mlir::TypeAttr::get(newType); - auto newCLKMAttr = cir::OpenCLKernelMetadataAttr::get( - getContext(), clKernelMetadata.getWorkGroupSizeHint(), - clKernelMetadata.getReqdWorkGroupSize(), newTypeHint, - clKernelMetadata.getVecTypeHintSignedness(), - clKernelMetadata.getIntelReqdSubGroupSize()); - entry.setValue(newCLKMAttr); - } + mlir::NamedAttrList newExtraAttrs; + for (auto entry : oldExtraAttrs.getElements()) { + if (entry.getName() == attrKey) { + auto clKernelMetadata = + cast(entry.getValue()); + if (auto vecTypeHint = clKernelMetadata.getVecTypeHint()) { + auto newType = typeConverter->convertType(vecTypeHint.getValue()); + auto newTypeHint = mlir::TypeAttr::get(newType); + auto newCLKMAttr = cir::OpenCLKernelMetadataAttr::get( + getContext(), clKernelMetadata.getWorkGroupSizeHint(), + clKernelMetadata.getReqdWorkGroupSize(), newTypeHint, + clKernelMetadata.getVecTypeHintSignedness(), + clKernelMetadata.getIntelReqdSubGroupSize()); + entry.setValue(newCLKMAttr); } - newExtraAttrs.push_back(entry); } - extraAttrsEntry.setValue(cir::ExtraFuncAttributesAttr::get( - getContext(), newExtraAttrs.getDictionary(getContext()))); + newExtraAttrs.push_back(entry); } + extraAttrsEntry.setValue(cir::ExtraFuncAttributesAttr::get( + getContext(), newExtraAttrs.getDictionary(getContext()))); +} - mlir::LogicalResult - matchAndRewrite(cir::FuncOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - - auto fnType = op.getFunctionType(); - auto isDsoLocal = op.getDsolocal(); - mlir::TypeConverter::SignatureConversion signatureConversion( - fnType.getNumInputs()); +mlir::LogicalResult CIRToLLVMFuncOpLowering::matchAndRewrite( + cir::FuncOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { - for (const auto &argType : enumerate(fnType.getInputs())) { - auto convertedType = typeConverter->convertType(argType.value()); - if (!convertedType) - return mlir::failure(); - signatureConversion.addInputs(argType.index(), convertedType); - } + auto fnType = op.getFunctionType(); + auto isDsoLocal = op.getDsolocal(); + mlir::TypeConverter::SignatureConversion signatureConversion( + fnType.getNumInputs()); - mlir::Type resultType = - getTypeConverter()->convertType(fnType.getReturnType()); - - // Create the LLVM function operation. - auto llvmFnTy = mlir::LLVM::LLVMFunctionType::get( - resultType ? resultType : mlir::LLVM::LLVMVoidType::get(getContext()), - signatureConversion.getConvertedTypes(), - /*isVarArg=*/fnType.isVarArg()); - // LLVMFuncOp expects a single FileLine Location instead of a fused - // location. - auto Loc = op.getLoc(); - if (mlir::isa(Loc)) { - auto FusedLoc = mlir::cast(Loc); - Loc = FusedLoc.getLocations()[0]; - } - assert((mlir::isa(Loc) || - mlir::isa(Loc)) && - "expected single location or unknown location here"); - - auto linkage = convertLinkage(op.getLinkage()); - auto cconv = convertCallingConv(op.getCallingConv()); - SmallVector attributes; - lowerFuncAttributes(op, /*filterArgAndResAttrs=*/false, attributes); - - auto fn = rewriter.create( - Loc, op.getName(), llvmFnTy, linkage, isDsoLocal, cconv, - mlir::SymbolRefAttr(), attributes); - - fn.setVisibility_Attr(mlir::LLVM::VisibilityAttr::get( - getContext(), lowerCIRVisibilityToLLVMVisibility( - op.getGlobalVisibilityAttr().getValue()))); - - rewriter.inlineRegionBefore(op.getBody(), fn.getBody(), fn.end()); - if (failed(rewriter.convertRegionTypes(&fn.getBody(), *typeConverter, - &signatureConversion))) + for (const auto &argType : enumerate(fnType.getInputs())) { + auto convertedType = typeConverter->convertType(argType.value()); + if (!convertedType) return mlir::failure(); + signatureConversion.addInputs(argType.index(), convertedType); + } + + mlir::Type resultType = + getTypeConverter()->convertType(fnType.getReturnType()); + + // Create the LLVM function operation. + auto llvmFnTy = mlir::LLVM::LLVMFunctionType::get( + resultType ? resultType : mlir::LLVM::LLVMVoidType::get(getContext()), + signatureConversion.getConvertedTypes(), + /*isVarArg=*/fnType.isVarArg()); + // LLVMFuncOp expects a single FileLine Location instead of a fused + // location. + auto Loc = op.getLoc(); + if (mlir::isa(Loc)) { + auto FusedLoc = mlir::cast(Loc); + Loc = FusedLoc.getLocations()[0]; + } + assert((mlir::isa(Loc) || + mlir::isa(Loc)) && + "expected single location or unknown location here"); + + auto linkage = convertLinkage(op.getLinkage()); + auto cconv = convertCallingConv(op.getCallingConv()); + SmallVector attributes; + lowerFuncAttributes(op, /*filterArgAndResAttrs=*/false, attributes); + + auto fn = rewriter.create( + Loc, op.getName(), llvmFnTy, linkage, isDsoLocal, cconv, + mlir::SymbolRefAttr(), attributes); + + fn.setVisibility_Attr(mlir::LLVM::VisibilityAttr::get( + getContext(), lowerCIRVisibilityToLLVMVisibility( + op.getGlobalVisibilityAttr().getValue()))); + + rewriter.inlineRegionBefore(op.getBody(), fn.getBody(), fn.end()); + if (failed(rewriter.convertRegionTypes(&fn.getBody(), *typeConverter, + &signatureConversion))) + return mlir::failure(); - rewriter.eraseOp(op); - - return mlir::LogicalResult::success(); - } -}; - -class CIRGetGlobalOpLowering - : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::GetGlobalOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - // FIXME(cir): Premature DCE to avoid lowering stuff we're not using. - // CIRGen should mitigate this and not emit the get_global. - if (op->getUses().empty()) { - rewriter.eraseOp(op); - return mlir::success(); - } - - auto type = getTypeConverter()->convertType(op.getType()); - auto symbol = op.getName(); - mlir::Operation *newop = - rewriter.create(op.getLoc(), type, symbol); + rewriter.eraseOp(op); - if (op.getTls()) { - // Handle access to TLS via intrinsic. - newop = rewriter.create( - op.getLoc(), type, newop->getResult(0)); - } + return mlir::LogicalResult::success(); +} - rewriter.replaceOp(op, newop); +mlir::LogicalResult CIRToLLVMGetGlobalOpLowering::matchAndRewrite( + cir::GetGlobalOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + // FIXME(cir): Premature DCE to avoid lowering stuff we're not using. + // CIRGen should mitigate this and not emit the get_global. + if (op->getUses().empty()) { + rewriter.eraseOp(op); return mlir::success(); } -}; -class CIRComplexCreateOpLowering - : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; + auto type = getTypeConverter()->convertType(op.getType()); + auto symbol = op.getName(); + mlir::Operation *newop = + rewriter.create(op.getLoc(), type, symbol); - mlir::LogicalResult - matchAndRewrite(cir::ComplexCreateOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - auto complexLLVMTy = - getTypeConverter()->convertType(op.getResult().getType()); - auto initialComplex = - rewriter.create(op->getLoc(), complexLLVMTy); + if (op.getTls()) { + // Handle access to TLS via intrinsic. + newop = rewriter.create( + op.getLoc(), type, newop->getResult(0)); + } - int64_t position[1]{0}; - auto realComplex = rewriter.create( - op->getLoc(), initialComplex, adaptor.getReal(), position); + rewriter.replaceOp(op, newop); + return mlir::success(); +} - position[0] = 1; - auto complex = rewriter.create( - op->getLoc(), realComplex, adaptor.getImag(), position); +mlir::LogicalResult CIRToLLVMComplexCreateOpLowering::matchAndRewrite( + cir::ComplexCreateOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto complexLLVMTy = + getTypeConverter()->convertType(op.getResult().getType()); + auto initialComplex = + rewriter.create(op->getLoc(), complexLLVMTy); - rewriter.replaceOp(op, complex); - return mlir::success(); - } -}; + int64_t position[1]{0}; + auto realComplex = rewriter.create( + op->getLoc(), initialComplex, adaptor.getReal(), position); -class CIRComplexRealOpLowering - : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::ComplexRealOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - auto resultLLVMTy = - getTypeConverter()->convertType(op.getResult().getType()); - rewriter.replaceOpWithNewOp( - op, resultLLVMTy, adaptor.getOperand(), - llvm::ArrayRef{0}); - return mlir::success(); - } -}; + position[0] = 1; + auto complex = rewriter.create( + op->getLoc(), realComplex, adaptor.getImag(), position); -class CIRComplexImagOpLowering - : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::ComplexImagOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - auto resultLLVMTy = - getTypeConverter()->convertType(op.getResult().getType()); - rewriter.replaceOpWithNewOp( - op, resultLLVMTy, adaptor.getOperand(), - llvm::ArrayRef{1}); - return mlir::success(); - } -}; + rewriter.replaceOp(op, complex); + return mlir::success(); +} -class CIRComplexRealPtrOpLowering - : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::ComplexRealPtrOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - auto operandTy = mlir::cast(op.getOperand().getType()); - auto resultLLVMTy = - getTypeConverter()->convertType(op.getResult().getType()); - auto elementLLVMTy = - getTypeConverter()->convertType(operandTy.getPointee()); - - mlir::LLVM::GEPArg gepIndices[2]{{0}, {0}}; - rewriter.replaceOpWithNewOp( - op, resultLLVMTy, elementLLVMTy, adaptor.getOperand(), gepIndices, - /*inbounds=*/true); +mlir::LogicalResult CIRToLLVMComplexRealOpLowering::matchAndRewrite( + cir::ComplexRealOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto resultLLVMTy = getTypeConverter()->convertType(op.getResult().getType()); + rewriter.replaceOpWithNewOp( + op, resultLLVMTy, adaptor.getOperand(), llvm::ArrayRef{0}); + return mlir::success(); +} - return mlir::success(); - } -}; +mlir::LogicalResult CIRToLLVMComplexImagOpLowering::matchAndRewrite( + cir::ComplexImagOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto resultLLVMTy = getTypeConverter()->convertType(op.getResult().getType()); + rewriter.replaceOpWithNewOp( + op, resultLLVMTy, adaptor.getOperand(), llvm::ArrayRef{1}); + return mlir::success(); +} -class CIRComplexImagPtrOpLowering - : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::ComplexImagPtrOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - auto operandTy = mlir::cast(op.getOperand().getType()); - auto resultLLVMTy = - getTypeConverter()->convertType(op.getResult().getType()); - auto elementLLVMTy = - getTypeConverter()->convertType(operandTy.getPointee()); - - mlir::LLVM::GEPArg gepIndices[2]{{0}, {1}}; - rewriter.replaceOpWithNewOp( - op, resultLLVMTy, elementLLVMTy, adaptor.getOperand(), gepIndices, - /*inbounds=*/true); +mlir::LogicalResult CIRToLLVMComplexRealPtrOpLowering::matchAndRewrite( + cir::ComplexRealPtrOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto operandTy = mlir::cast(op.getOperand().getType()); + auto resultLLVMTy = getTypeConverter()->convertType(op.getResult().getType()); + auto elementLLVMTy = getTypeConverter()->convertType(operandTy.getPointee()); - return mlir::success(); - } -}; + mlir::LLVM::GEPArg gepIndices[2]{{0}, {0}}; + rewriter.replaceOpWithNewOp( + op, resultLLVMTy, elementLLVMTy, adaptor.getOperand(), gepIndices, + /*inbounds=*/true); -class CIRSwitchFlatOpLowering - : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; + return mlir::success(); +} - mlir::LogicalResult - matchAndRewrite(cir::SwitchFlatOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { +mlir::LogicalResult CIRToLLVMComplexImagPtrOpLowering::matchAndRewrite( + cir::ComplexImagPtrOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto operandTy = mlir::cast(op.getOperand().getType()); + auto resultLLVMTy = getTypeConverter()->convertType(op.getResult().getType()); + auto elementLLVMTy = getTypeConverter()->convertType(operandTy.getPointee()); - llvm::SmallVector caseValues; - if (op.getCaseValues()) { - for (auto val : op.getCaseValues()) { - auto intAttr = dyn_cast(val); - caseValues.push_back(intAttr.getValue()); - } - } + mlir::LLVM::GEPArg gepIndices[2]{{0}, {1}}; + rewriter.replaceOpWithNewOp( + op, resultLLVMTy, elementLLVMTy, adaptor.getOperand(), gepIndices, + /*inbounds=*/true); - llvm::SmallVector caseDestinations; - llvm::SmallVector caseOperands; + return mlir::success(); +} - for (auto x : op.getCaseDestinations()) { - caseDestinations.push_back(x); - } +mlir::LogicalResult CIRToLLVMSwitchFlatOpLowering::matchAndRewrite( + cir::SwitchFlatOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { - for (auto x : op.getCaseOperands()) { - caseOperands.push_back(x); + llvm::SmallVector caseValues; + if (op.getCaseValues()) { + for (auto val : op.getCaseValues()) { + auto intAttr = dyn_cast(val); + caseValues.push_back(intAttr.getValue()); } + } - // Set switch op to branch to the newly created blocks. - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, adaptor.getCondition(), op.getDefaultDestination(), - op.getDefaultOperands(), caseValues, caseDestinations, caseOperands); - return mlir::success(); + llvm::SmallVector caseDestinations; + llvm::SmallVector caseOperands; + + for (auto x : op.getCaseDestinations()) { + caseDestinations.push_back(x); } -}; -class CIRGlobalOpLowering : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; + for (auto x : op.getCaseOperands()) { + caseOperands.push_back(x); + } + + // Set switch op to branch to the newly created blocks. + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, adaptor.getCondition(), op.getDefaultDestination(), + op.getDefaultOperands(), caseValues, caseDestinations, caseOperands); + return mlir::success(); +} /// Replace CIR global with a region initialized LLVM global and update /// insertion point to the end of the initializer block. - inline void setupRegionInitializedLLVMGlobalOp( - cir::GlobalOp op, mlir::ConversionPatternRewriter &rewriter) const { - const auto llvmType = getTypeConverter()->convertType(op.getSymType()); - SmallVector attributes; - auto newGlobalOp = rewriter.replaceOpWithNewOp( - op, llvmType, op.getConstant(), convertLinkage(op.getLinkage()), - op.getSymName(), nullptr, - /*alignment*/ op.getAlignment().value_or(0), +void CIRToLLVMGlobalOpLowering::setupRegionInitializedLLVMGlobalOp( + cir::GlobalOp op, mlir::ConversionPatternRewriter &rewriter) const { + const auto llvmType = getTypeConverter()->convertType(op.getSymType()); + SmallVector attributes; + auto newGlobalOp = rewriter.replaceOpWithNewOp( + op, llvmType, op.getConstant(), convertLinkage(op.getLinkage()), + op.getSymName(), nullptr, + /*alignment*/ op.getAlignment().value_or(0), + /*addrSpace*/ getGlobalOpTargetAddrSpace(rewriter, typeConverter, op), + /*dsoLocal*/ false, /*threadLocal*/ (bool)op.getTlsModelAttr(), + /*comdat*/ mlir::SymbolRefAttr(), attributes); + newGlobalOp.getRegion().push_back(new mlir::Block()); + rewriter.setInsertionPointToEnd(newGlobalOp.getInitializerBlock()); +} + +mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite( + cir::GlobalOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + + // Fetch required values to create LLVM op. + const auto llvmType = getTypeConverter()->convertType(op.getSymType()); + const auto isConst = op.getConstant(); + const auto isDsoLocal = op.getDsolocal(); + const auto linkage = convertLinkage(op.getLinkage()); + const auto symbol = op.getSymName(); + const auto loc = op.getLoc(); + std::optional section = op.getSection(); + std::optional init = op.getInitialValue(); + mlir::LLVM::VisibilityAttr visibility = mlir::LLVM::VisibilityAttr::get( + getContext(), lowerCIRVisibilityToLLVMVisibility( + op.getGlobalVisibilityAttr().getValue())); + + SmallVector attributes; + if (section.has_value()) + attributes.push_back(rewriter.getNamedAttr( + "section", rewriter.getStringAttr(section.value()))); + + attributes.push_back(rewriter.getNamedAttr("visibility_", visibility)); + + // Check for missing funcionalities. + if (!init.has_value()) { + rewriter.replaceOpWithNewOp( + op, llvmType, isConst, linkage, symbol, mlir::Attribute(), + /*alignment*/ 0, /*addrSpace*/ getGlobalOpTargetAddrSpace(rewriter, typeConverter, op), - /*dsoLocal*/ false, /*threadLocal*/ (bool)op.getTlsModelAttr(), + /*dsoLocal*/ isDsoLocal, /*threadLocal*/ (bool)op.getTlsModelAttr(), /*comdat*/ mlir::SymbolRefAttr(), attributes); - newGlobalOp.getRegion().push_back(new mlir::Block()); - rewriter.setInsertionPointToEnd(newGlobalOp.getInitializerBlock()); + return mlir::success(); } - mlir::LogicalResult - matchAndRewrite(cir::GlobalOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - - // Fetch required values to create LLVM op. - const auto llvmType = getTypeConverter()->convertType(op.getSymType()); - const auto isConst = op.getConstant(); - const auto isDsoLocal = op.getDsolocal(); - const auto linkage = convertLinkage(op.getLinkage()); - const auto symbol = op.getSymName(); - const auto loc = op.getLoc(); - std::optional section = op.getSection(); - std::optional init = op.getInitialValue(); - mlir::LLVM::VisibilityAttr visibility = mlir::LLVM::VisibilityAttr::get( - getContext(), lowerCIRVisibilityToLLVMVisibility( - op.getGlobalVisibilityAttr().getValue())); - - SmallVector attributes; - if (section.has_value()) - attributes.push_back(rewriter.getNamedAttr( - "section", rewriter.getStringAttr(section.value()))); - - attributes.push_back(rewriter.getNamedAttr("visibility_", visibility)); - - // Check for missing funcionalities. - if (!init.has_value()) { - rewriter.replaceOpWithNewOp( - op, llvmType, isConst, linkage, symbol, mlir::Attribute(), - /*alignment*/ 0, - /*addrSpace*/ getGlobalOpTargetAddrSpace(rewriter, typeConverter, op), - /*dsoLocal*/ isDsoLocal, /*threadLocal*/ (bool)op.getTlsModelAttr(), - /*comdat*/ mlir::SymbolRefAttr(), attributes); - return mlir::success(); - } - - // Initializer is a constant array: convert it to a compatible llvm init. - if (auto constArr = mlir::dyn_cast(init.value())) { - if (auto attr = mlir::dyn_cast(constArr.getElts())) { - llvm::SmallString<256> literal(attr.getValue()); - if (constArr.getTrailingZerosNum()) - literal.append(constArr.getTrailingZerosNum(), '\0'); - init = rewriter.getStringAttr(literal); - } else if (auto attr = - mlir::dyn_cast(constArr.getElts())) { - // Failed to use a compact attribute as an initializer: - // initialize elements individually. - if (!(init = lowerConstArrayAttr(constArr, getTypeConverter()))) { - setupRegionInitializedLLVMGlobalOp(op, rewriter); - rewriter.create( - op->getLoc(), - lowerCirAttrAsValue(op, constArr, rewriter, typeConverter)); - return mlir::success(); - } - } else { - op.emitError() - << "unsupported lowering for #cir.const_array with value " - << constArr.getElts(); - return mlir::failure(); + // Initializer is a constant array: convert it to a compatible llvm init. + if (auto constArr = mlir::dyn_cast(init.value())) { + if (auto attr = mlir::dyn_cast(constArr.getElts())) { + llvm::SmallString<256> literal(attr.getValue()); + if (constArr.getTrailingZerosNum()) + literal.append(constArr.getTrailingZerosNum(), '\0'); + init = rewriter.getStringAttr(literal); + } else if (auto attr = + mlir::dyn_cast(constArr.getElts())) { + // Failed to use a compact attribute as an initializer: + // initialize elements individually. + if (!(init = lowerConstArrayAttr(constArr, getTypeConverter()))) { + setupRegionInitializedLLVMGlobalOp(op, rewriter); + rewriter.create( + op->getLoc(), + lowerCirAttrAsValue(op, constArr, rewriter, typeConverter)); + return mlir::success(); } - } else if (auto fltAttr = mlir::dyn_cast(init.value())) { - // Initializer is a constant floating-point number: convert to MLIR - // builtin constant. - init = rewriter.getFloatAttr(llvmType, fltAttr.getValue()); - } - // Initializer is a constant integer: convert to MLIR builtin constant. - else if (auto intAttr = mlir::dyn_cast(init.value())) { - init = rewriter.getIntegerAttr(llvmType, intAttr.getValue()); - } else if (auto boolAttr = mlir::dyn_cast(init.value())) { - init = rewriter.getBoolAttr(boolAttr.getValue()); - } else if (isa( - init.value())) { - // TODO(cir): once LLVM's dialect has proper equivalent attributes this - // should be updated. For now, we use a custom op to initialize globals - // to the appropriate value. - setupRegionInitializedLLVMGlobalOp(op, rewriter); - auto value = - lowerCirAttrAsValue(op, init.value(), rewriter, typeConverter); - rewriter.create(loc, value); - return mlir::success(); - } else if (auto dataMemberAttr = - mlir::dyn_cast(init.value())) { - init = lowerDataMemberAttr(op->getParentOfType(), - dataMemberAttr, *typeConverter); - } else if (const auto structAttr = - mlir::dyn_cast(init.value())) { - setupRegionInitializedLLVMGlobalOp(op, rewriter); - rewriter.create( - op->getLoc(), - lowerCirAttrAsValue(op, structAttr, rewriter, typeConverter)); - return mlir::success(); - } else if (auto attr = mlir::dyn_cast(init.value())) { - setupRegionInitializedLLVMGlobalOp(op, rewriter); - rewriter.create( - loc, lowerCirAttrAsValue(op, attr, rewriter, typeConverter)); - return mlir::success(); - } else if (const auto vtableAttr = - mlir::dyn_cast(init.value())) { - setupRegionInitializedLLVMGlobalOp(op, rewriter); - rewriter.create( - op->getLoc(), - lowerCirAttrAsValue(op, vtableAttr, rewriter, typeConverter)); - return mlir::success(); - } else if (const auto typeinfoAttr = - mlir::dyn_cast(init.value())) { - setupRegionInitializedLLVMGlobalOp(op, rewriter); - rewriter.create( - op->getLoc(), - lowerCirAttrAsValue(op, typeinfoAttr, rewriter, typeConverter)); - return mlir::success(); } else { - op.emitError() << "unsupported initializer '" << init.value() << "'"; + op.emitError() << "unsupported lowering for #cir.const_array with value " + << constArr.getElts(); return mlir::failure(); } + } else if (auto fltAttr = mlir::dyn_cast(init.value())) { + // Initializer is a constant floating-point number: convert to MLIR + // builtin constant. + init = rewriter.getFloatAttr(llvmType, fltAttr.getValue()); + } + // Initializer is a constant integer: convert to MLIR builtin constant. + else if (auto intAttr = mlir::dyn_cast(init.value())) { + init = rewriter.getIntegerAttr(llvmType, intAttr.getValue()); + } else if (auto boolAttr = mlir::dyn_cast(init.value())) { + init = rewriter.getBoolAttr(boolAttr.getValue()); + } else if (isa( + init.value())) { + // TODO(cir): once LLVM's dialect has proper equivalent attributes this + // should be updated. For now, we use a custom op to initialize globals + // to the appropriate value. + setupRegionInitializedLLVMGlobalOp(op, rewriter); + auto value = lowerCirAttrAsValue(op, init.value(), rewriter, typeConverter); + rewriter.create(loc, value); + return mlir::success(); + } else if (auto dataMemberAttr = + mlir::dyn_cast(init.value())) { + init = lowerDataMemberAttr(op->getParentOfType(), + dataMemberAttr, *typeConverter); + } else if (const auto structAttr = + mlir::dyn_cast(init.value())) { + setupRegionInitializedLLVMGlobalOp(op, rewriter); + rewriter.create( + op->getLoc(), + lowerCirAttrAsValue(op, structAttr, rewriter, typeConverter)); + return mlir::success(); + } else if (auto attr = mlir::dyn_cast(init.value())) { + setupRegionInitializedLLVMGlobalOp(op, rewriter); + rewriter.create( + loc, lowerCirAttrAsValue(op, attr, rewriter, typeConverter)); + return mlir::success(); + } else if (const auto vtableAttr = + mlir::dyn_cast(init.value())) { + setupRegionInitializedLLVMGlobalOp(op, rewriter); + rewriter.create( + op->getLoc(), + lowerCirAttrAsValue(op, vtableAttr, rewriter, typeConverter)); + return mlir::success(); + } else if (const auto typeinfoAttr = + mlir::dyn_cast(init.value())) { + setupRegionInitializedLLVMGlobalOp(op, rewriter); + rewriter.create( + op->getLoc(), + lowerCirAttrAsValue(op, typeinfoAttr, rewriter, typeConverter)); + return mlir::success(); + } else { + op.emitError() << "unsupported initializer '" << init.value() << "'"; + return mlir::failure(); + } - // Rewrite op. - auto llvmGlobalOp = rewriter.replaceOpWithNewOp( - op, llvmType, isConst, linkage, symbol, init.value(), - /*alignment*/ op.getAlignment().value_or(0), - /*addrSpace*/ getGlobalOpTargetAddrSpace(rewriter, typeConverter, op), - /*dsoLocal*/ false, /*threadLocal*/ (bool)op.getTlsModelAttr(), - /*comdat*/ mlir::SymbolRefAttr(), attributes); + // Rewrite op. + auto llvmGlobalOp = rewriter.replaceOpWithNewOp( + op, llvmType, isConst, linkage, symbol, init.value(), + /*alignment*/ op.getAlignment().value_or(0), + /*addrSpace*/ getGlobalOpTargetAddrSpace(rewriter, typeConverter, op), + /*dsoLocal*/ false, /*threadLocal*/ (bool)op.getTlsModelAttr(), + /*comdat*/ mlir::SymbolRefAttr(), attributes); - auto mod = op->getParentOfType(); - if (op.getComdat()) - addComdat(llvmGlobalOp, comdatOp, rewriter, mod); + auto mod = op->getParentOfType(); + if (op.getComdat()) + addComdat(llvmGlobalOp, comdatOp, rewriter, mod); - return mlir::success(); - } + return mlir::success(); +} -private: - mutable mlir::LLVM::ComdatOp comdatOp = nullptr; - static void addComdat(mlir::LLVM::GlobalOp &op, - mlir::LLVM::ComdatOp &comdatOp, - mlir::OpBuilder &builder, mlir::ModuleOp &module) { - StringRef comdatName("__llvm_comdat_globals"); - if (!comdatOp) { - builder.setInsertionPointToStart(module.getBody()); - comdatOp = - builder.create(module.getLoc(), comdatName); - } - builder.setInsertionPointToStart(&comdatOp.getBody().back()); - auto selectorOp = builder.create( - comdatOp.getLoc(), op.getSymName(), mlir::LLVM::comdat::Comdat::Any); - op.setComdatAttr(mlir::SymbolRefAttr::get( - builder.getContext(), comdatName, - mlir::FlatSymbolRefAttr::get(selectorOp.getSymNameAttr()))); - } -}; +void CIRToLLVMGlobalOpLowering::addComdat(mlir::LLVM::GlobalOp &op, + mlir::LLVM::ComdatOp &comdatOp, + mlir::OpBuilder &builder, + mlir::ModuleOp &module) { + StringRef comdatName("__llvm_comdat_globals"); + if (!comdatOp) { + builder.setInsertionPointToStart(module.getBody()); + comdatOp = + builder.create(module.getLoc(), comdatName); + } + builder.setInsertionPointToStart(&comdatOp.getBody().back()); + auto selectorOp = builder.create( + comdatOp.getLoc(), op.getSymName(), mlir::LLVM::comdat::Comdat::Any); + op.setComdatAttr(mlir::SymbolRefAttr::get( + builder.getContext(), comdatName, + mlir::FlatSymbolRefAttr::get(selectorOp.getSymNameAttr()))); +} -class CIRUnaryOpLowering : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::UnaryOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - assert(op.getType() == op.getInput().getType() && - "Unary operation's operand type and result type are different"); - mlir::Type type = op.getType(); - mlir::Type elementType = elementTypeIfVector(type); - bool IsVector = mlir::isa(type); - auto llvmType = getTypeConverter()->convertType(type); - auto loc = op.getLoc(); - - // Integer unary operations: + - ~ ++ -- - if (mlir::isa(elementType)) { - switch (op.getKind()) { - case cir::UnaryOpKind::Inc: { - assert(!IsVector && "++ not allowed on vector types"); - auto One = rewriter.create( - loc, llvmType, mlir::IntegerAttr::get(llvmType, 1)); - rewriter.replaceOpWithNewOp(op, llvmType, - adaptor.getInput(), One); - return mlir::success(); - } - case cir::UnaryOpKind::Dec: { - assert(!IsVector && "-- not allowed on vector types"); - auto One = rewriter.create( - loc, llvmType, mlir::IntegerAttr::get(llvmType, 1)); - rewriter.replaceOpWithNewOp(op, llvmType, - adaptor.getInput(), One); - return mlir::success(); - } - case cir::UnaryOpKind::Plus: { - rewriter.replaceOp(op, adaptor.getInput()); - return mlir::success(); - } - case cir::UnaryOpKind::Minus: { - mlir::Value Zero; - if (IsVector) - Zero = rewriter.create(loc, llvmType); - else - Zero = rewriter.create( - loc, llvmType, mlir::IntegerAttr::get(llvmType, 0)); - rewriter.replaceOpWithNewOp(op, llvmType, Zero, - adaptor.getInput()); - return mlir::success(); - } - case cir::UnaryOpKind::Not: { - // bit-wise compliment operator, implemented as an XOR with -1. - mlir::Value MinusOne; - if (IsVector) { - // Creating a vector object with all -1 values is easier said than - // done. It requires a series of insertelement ops. - mlir::Type llvmElementType = - getTypeConverter()->convertType(elementType); - auto MinusOneInt = rewriter.create( - loc, llvmElementType, - mlir::IntegerAttr::get(llvmElementType, -1)); - MinusOne = rewriter.create(loc, llvmType); - auto NumElements = mlir::dyn_cast(type).getSize(); - for (uint64_t i = 0; i < NumElements; ++i) { - mlir::Value indexValue = rewriter.create( - loc, rewriter.getI64Type(), i); - MinusOne = rewriter.create( - loc, MinusOne, MinusOneInt, indexValue); - } - } else { - MinusOne = rewriter.create( - loc, llvmType, mlir::IntegerAttr::get(llvmType, -1)); +mlir::LogicalResult CIRToLLVMUnaryOpLowering::matchAndRewrite( + cir::UnaryOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + assert(op.getType() == op.getInput().getType() && + "Unary operation's operand type and result type are different"); + mlir::Type type = op.getType(); + mlir::Type elementType = elementTypeIfVector(type); + bool IsVector = mlir::isa(type); + auto llvmType = getTypeConverter()->convertType(type); + auto loc = op.getLoc(); + + // Integer unary operations: + - ~ ++ -- + if (mlir::isa(elementType)) { + switch (op.getKind()) { + case cir::UnaryOpKind::Inc: { + assert(!IsVector && "++ not allowed on vector types"); + auto One = rewriter.create( + loc, llvmType, mlir::IntegerAttr::get(llvmType, 1)); + rewriter.replaceOpWithNewOp(op, llvmType, + adaptor.getInput(), One); + return mlir::success(); + } + case cir::UnaryOpKind::Dec: { + assert(!IsVector && "-- not allowed on vector types"); + auto One = rewriter.create( + loc, llvmType, mlir::IntegerAttr::get(llvmType, 1)); + rewriter.replaceOpWithNewOp(op, llvmType, + adaptor.getInput(), One); + return mlir::success(); + } + case cir::UnaryOpKind::Plus: { + rewriter.replaceOp(op, adaptor.getInput()); + return mlir::success(); + } + case cir::UnaryOpKind::Minus: { + mlir::Value Zero; + if (IsVector) + Zero = rewriter.create(loc, llvmType); + else + Zero = rewriter.create( + loc, llvmType, mlir::IntegerAttr::get(llvmType, 0)); + rewriter.replaceOpWithNewOp(op, llvmType, Zero, + adaptor.getInput()); + return mlir::success(); + } + case cir::UnaryOpKind::Not: { + // bit-wise compliment operator, implemented as an XOR with -1. + mlir::Value MinusOne; + if (IsVector) { + // Creating a vector object with all -1 values is easier said than + // done. It requires a series of insertelement ops. + mlir::Type llvmElementType = + getTypeConverter()->convertType(elementType); + auto MinusOneInt = rewriter.create( + loc, llvmElementType, mlir::IntegerAttr::get(llvmElementType, -1)); + MinusOne = rewriter.create(loc, llvmType); + auto NumElements = mlir::dyn_cast(type).getSize(); + for (uint64_t i = 0; i < NumElements; ++i) { + mlir::Value indexValue = rewriter.create( + loc, rewriter.getI64Type(), i); + MinusOne = rewriter.create( + loc, MinusOne, MinusOneInt, indexValue); } - rewriter.replaceOpWithNewOp(op, llvmType, MinusOne, - adaptor.getInput()); - return mlir::success(); - } + } else { + MinusOne = rewriter.create( + loc, llvmType, mlir::IntegerAttr::get(llvmType, -1)); } + rewriter.replaceOpWithNewOp(op, llvmType, MinusOne, + adaptor.getInput()); + return mlir::success(); } - - // Floating point unary operations: + - ++ -- - if (mlir::isa(elementType)) { - switch (op.getKind()) { - case cir::UnaryOpKind::Inc: { - assert(!IsVector && "++ not allowed on vector types"); - auto oneAttr = rewriter.getFloatAttr(llvmType, 1.0); - auto oneConst = - rewriter.create(loc, llvmType, oneAttr); - rewriter.replaceOpWithNewOp(op, llvmType, oneConst, - adaptor.getInput()); - return mlir::success(); - } - case cir::UnaryOpKind::Dec: { - assert(!IsVector && "-- not allowed on vector types"); - auto negOneAttr = rewriter.getFloatAttr(llvmType, -1.0); - auto negOneConst = - rewriter.create(loc, llvmType, negOneAttr); - rewriter.replaceOpWithNewOp( - op, llvmType, negOneConst, adaptor.getInput()); - return mlir::success(); - } - case cir::UnaryOpKind::Plus: - rewriter.replaceOp(op, adaptor.getInput()); - return mlir::success(); - case cir::UnaryOpKind::Minus: { - rewriter.replaceOpWithNewOp(op, llvmType, - adaptor.getInput()); - return mlir::success(); - } - default: - return op.emitError() - << "Unknown floating-point unary operation during CIR lowering"; - } } + } - // Boolean unary operations: ! only. (For all others, the operand has - // already been promoted to int.) - if (mlir::isa(elementType)) { - switch (op.getKind()) { - case cir::UnaryOpKind::Not: - assert(!IsVector && "NYI: op! on vector mask"); - rewriter.replaceOpWithNewOp( - op, llvmType, adaptor.getInput(), - rewriter.create( - loc, llvmType, mlir::IntegerAttr::get(llvmType, 1))); - return mlir::success(); - default: - return op.emitError() - << "Unknown boolean unary operation during CIR lowering"; - } + // Floating point unary operations: + - ++ -- + if (mlir::isa(elementType)) { + switch (op.getKind()) { + case cir::UnaryOpKind::Inc: { + assert(!IsVector && "++ not allowed on vector types"); + auto oneAttr = rewriter.getFloatAttr(llvmType, 1.0); + auto oneConst = + rewriter.create(loc, llvmType, oneAttr); + rewriter.replaceOpWithNewOp(op, llvmType, oneConst, + adaptor.getInput()); + return mlir::success(); } - - // Pointer unary operations: + only. (++ and -- of pointers are implemented - // with cir.ptr_stride, not cir.unary.) - if (mlir::isa(elementType)) { - switch (op.getKind()) { - case cir::UnaryOpKind::Plus: - rewriter.replaceOp(op, adaptor.getInput()); - return mlir::success(); - default: - op.emitError() << "Unknown pointer unary operation during CIR lowering"; - return mlir::failure(); - } + case cir::UnaryOpKind::Dec: { + assert(!IsVector && "-- not allowed on vector types"); + auto negOneAttr = rewriter.getFloatAttr(llvmType, -1.0); + auto negOneConst = + rewriter.create(loc, llvmType, negOneAttr); + rewriter.replaceOpWithNewOp(op, llvmType, negOneConst, + adaptor.getInput()); + return mlir::success(); + } + case cir::UnaryOpKind::Plus: + rewriter.replaceOp(op, adaptor.getInput()); + return mlir::success(); + case cir::UnaryOpKind::Minus: { + rewriter.replaceOpWithNewOp(op, llvmType, + adaptor.getInput()); + return mlir::success(); + } + default: + return op.emitError() + << "Unknown floating-point unary operation during CIR lowering"; } - - return op.emitError() << "Unary operation has unsupported type: " - << elementType; } -}; - -class CIRBinOpLowering : public mlir::OpConversionPattern { - mlir::LLVM::IntegerOverflowFlags getIntOverflowFlag(cir::BinOp op) const { - if (op.getNoUnsignedWrap()) - return mlir::LLVM::IntegerOverflowFlags::nuw; - - if (op.getNoSignedWrap()) - return mlir::LLVM::IntegerOverflowFlags::nsw; + // Boolean unary operations: ! only. (For all others, the operand has + // already been promoted to int.) + if (mlir::isa(elementType)) { + switch (op.getKind()) { + case cir::UnaryOpKind::Not: + assert(!IsVector && "NYI: op! on vector mask"); + rewriter.replaceOpWithNewOp( + op, llvmType, adaptor.getInput(), + rewriter.create( + loc, llvmType, mlir::IntegerAttr::get(llvmType, 1))); + return mlir::success(); + default: + return op.emitError() + << "Unknown boolean unary operation during CIR lowering"; + } + } - return mlir::LLVM::IntegerOverflowFlags::none; + // Pointer unary operations: + only. (++ and -- of pointers are implemented + // with cir.ptr_stride, not cir.unary.) + if (mlir::isa(elementType)) { + switch (op.getKind()) { + case cir::UnaryOpKind::Plus: + rewriter.replaceOp(op, adaptor.getInput()); + return mlir::success(); + default: + op.emitError() << "Unknown pointer unary operation during CIR lowering"; + return mlir::failure(); + } } -public: - using OpConversionPattern::OpConversionPattern; + return op.emitError() << "Unary operation has unsupported type: " + << elementType; +} - mlir::LogicalResult - matchAndRewrite(cir::BinOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - assert((op.getLhs().getType() == op.getRhs().getType()) && - "inconsistent operands' types not supported yet"); - mlir::Type type = op.getRhs().getType(); - assert((mlir::isa( - type)) && - "operand type not supported yet"); +mlir::LLVM::IntegerOverflowFlags +CIRToLLVMBinOpLowering::getIntOverflowFlag(cir::BinOp op) const { + if (op.getNoUnsignedWrap()) + return mlir::LLVM::IntegerOverflowFlags::nuw; - auto llvmTy = getTypeConverter()->convertType(op.getType()); - auto rhs = adaptor.getRhs(); - auto lhs = adaptor.getLhs(); + if (op.getNoSignedWrap()) + return mlir::LLVM::IntegerOverflowFlags::nsw; - type = elementTypeIfVector(type); + return mlir::LLVM::IntegerOverflowFlags::none; +} - switch (op.getKind()) { - case cir::BinOpKind::Add: - if (mlir::isa(type)) - rewriter.replaceOpWithNewOp(op, llvmTy, lhs, rhs, - getIntOverflowFlag(op)); - else - rewriter.replaceOpWithNewOp(op, llvmTy, lhs, rhs); - break; - case cir::BinOpKind::Sub: - if (mlir::isa(type)) - rewriter.replaceOpWithNewOp(op, llvmTy, lhs, rhs, - getIntOverflowFlag(op)); +mlir::LogicalResult CIRToLLVMBinOpLowering::matchAndRewrite( + cir::BinOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + assert((op.getLhs().getType() == op.getRhs().getType()) && + "inconsistent operands' types not supported yet"); + mlir::Type type = op.getRhs().getType(); + assert((mlir::isa( + type)) && + "operand type not supported yet"); + + auto llvmTy = getTypeConverter()->convertType(op.getType()); + auto rhs = adaptor.getRhs(); + auto lhs = adaptor.getLhs(); + + type = elementTypeIfVector(type); + + switch (op.getKind()) { + case cir::BinOpKind::Add: + if (mlir::isa(type)) + rewriter.replaceOpWithNewOp(op, llvmTy, lhs, rhs, + getIntOverflowFlag(op)); + else + rewriter.replaceOpWithNewOp(op, llvmTy, lhs, rhs); + break; + case cir::BinOpKind::Sub: + if (mlir::isa(type)) + rewriter.replaceOpWithNewOp(op, llvmTy, lhs, rhs, + getIntOverflowFlag(op)); + else + rewriter.replaceOpWithNewOp(op, llvmTy, lhs, rhs); + break; + case cir::BinOpKind::Mul: + if (mlir::isa(type)) + rewriter.replaceOpWithNewOp(op, llvmTy, lhs, rhs, + getIntOverflowFlag(op)); + else + rewriter.replaceOpWithNewOp(op, llvmTy, lhs, rhs); + break; + case cir::BinOpKind::Div: + if (auto ty = mlir::dyn_cast(type)) { + if (ty.isUnsigned()) + rewriter.replaceOpWithNewOp(op, llvmTy, lhs, rhs); else - rewriter.replaceOpWithNewOp(op, llvmTy, lhs, rhs); - break; - case cir::BinOpKind::Mul: - if (mlir::isa(type)) - rewriter.replaceOpWithNewOp(op, llvmTy, lhs, rhs, - getIntOverflowFlag(op)); + rewriter.replaceOpWithNewOp(op, llvmTy, lhs, rhs); + } else + rewriter.replaceOpWithNewOp(op, llvmTy, lhs, rhs); + break; + case cir::BinOpKind::Rem: + if (auto ty = mlir::dyn_cast(type)) { + if (ty.isUnsigned()) + rewriter.replaceOpWithNewOp(op, llvmTy, lhs, rhs); else - rewriter.replaceOpWithNewOp(op, llvmTy, lhs, rhs); - break; - case cir::BinOpKind::Div: - if (auto ty = mlir::dyn_cast(type)) { - if (ty.isUnsigned()) - rewriter.replaceOpWithNewOp(op, llvmTy, lhs, rhs); - else - rewriter.replaceOpWithNewOp(op, llvmTy, lhs, rhs); - } else - rewriter.replaceOpWithNewOp(op, llvmTy, lhs, rhs); - break; - case cir::BinOpKind::Rem: - if (auto ty = mlir::dyn_cast(type)) { - if (ty.isUnsigned()) - rewriter.replaceOpWithNewOp(op, llvmTy, lhs, rhs); - else - rewriter.replaceOpWithNewOp(op, llvmTy, lhs, rhs); - } else - rewriter.replaceOpWithNewOp(op, llvmTy, lhs, rhs); - break; - case cir::BinOpKind::And: - rewriter.replaceOpWithNewOp(op, llvmTy, lhs, rhs); - break; - case cir::BinOpKind::Or: - rewriter.replaceOpWithNewOp(op, llvmTy, lhs, rhs); - break; - case cir::BinOpKind::Xor: - rewriter.replaceOpWithNewOp(op, llvmTy, lhs, rhs); - break; - } - - return mlir::LogicalResult::success(); - } -}; + rewriter.replaceOpWithNewOp(op, llvmTy, lhs, rhs); + } else + rewriter.replaceOpWithNewOp(op, llvmTy, lhs, rhs); + break; + case cir::BinOpKind::And: + rewriter.replaceOpWithNewOp(op, llvmTy, lhs, rhs); + break; + case cir::BinOpKind::Or: + rewriter.replaceOpWithNewOp(op, llvmTy, lhs, rhs); + break; + case cir::BinOpKind::Xor: + rewriter.replaceOpWithNewOp(op, llvmTy, lhs, rhs); + break; + } + + return mlir::LogicalResult::success(); +} -class CIRBinOpOverflowOpLowering - : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::BinOpOverflowOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto arithKind = op.getKind(); - auto operandTy = op.getLhs().getType(); - auto resultTy = op.getResult().getType(); - - auto encompassedTyInfo = computeEncompassedTypeWidth(operandTy, resultTy); - auto encompassedLLVMTy = rewriter.getIntegerType(encompassedTyInfo.width); - - auto lhs = adaptor.getLhs(); - auto rhs = adaptor.getRhs(); - if (operandTy.getWidth() < encompassedTyInfo.width) { - if (operandTy.isSigned()) { - lhs = rewriter.create(loc, encompassedLLVMTy, lhs); - rhs = rewriter.create(loc, encompassedLLVMTy, rhs); - } else { - lhs = rewriter.create(loc, encompassedLLVMTy, lhs); - rhs = rewriter.create(loc, encompassedLLVMTy, rhs); - } +mlir::LogicalResult CIRToLLVMBinOpOverflowOpLowering::matchAndRewrite( + cir::BinOpOverflowOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + auto arithKind = op.getKind(); + auto operandTy = op.getLhs().getType(); + auto resultTy = op.getResult().getType(); + + auto encompassedTyInfo = computeEncompassedTypeWidth(operandTy, resultTy); + auto encompassedLLVMTy = rewriter.getIntegerType(encompassedTyInfo.width); + + auto lhs = adaptor.getLhs(); + auto rhs = adaptor.getRhs(); + if (operandTy.getWidth() < encompassedTyInfo.width) { + if (operandTy.isSigned()) { + lhs = rewriter.create(loc, encompassedLLVMTy, lhs); + rhs = rewriter.create(loc, encompassedLLVMTy, rhs); + } else { + lhs = rewriter.create(loc, encompassedLLVMTy, lhs); + rhs = rewriter.create(loc, encompassedLLVMTy, rhs); } + } - auto intrinName = getLLVMIntrinName(arithKind, encompassedTyInfo.sign, - encompassedTyInfo.width); - auto intrinNameAttr = mlir::StringAttr::get(op.getContext(), intrinName); + auto intrinName = getLLVMIntrinName(arithKind, encompassedTyInfo.sign, + encompassedTyInfo.width); + auto intrinNameAttr = mlir::StringAttr::get(op.getContext(), intrinName); - auto overflowLLVMTy = rewriter.getI1Type(); - auto intrinRetTy = mlir::LLVM::LLVMStructType::getLiteral( - rewriter.getContext(), {encompassedLLVMTy, overflowLLVMTy}); + auto overflowLLVMTy = rewriter.getI1Type(); + auto intrinRetTy = mlir::LLVM::LLVMStructType::getLiteral( + rewriter.getContext(), {encompassedLLVMTy, overflowLLVMTy}); - auto callLLVMIntrinOp = rewriter.create( - loc, intrinRetTy, intrinNameAttr, mlir::ValueRange{lhs, rhs}); - auto intrinRet = callLLVMIntrinOp.getResult(0); + auto callLLVMIntrinOp = rewriter.create( + loc, intrinRetTy, intrinNameAttr, mlir::ValueRange{lhs, rhs}); + auto intrinRet = callLLVMIntrinOp.getResult(0); - auto result = rewriter + auto result = rewriter + .create(loc, intrinRet, + ArrayRef{0}) + .getResult(); + auto overflow = rewriter .create(loc, intrinRet, - ArrayRef{0}) + ArrayRef{1}) .getResult(); - auto overflow = rewriter - .create( - loc, intrinRet, ArrayRef{1}) - .getResult(); - - if (resultTy.getWidth() < encompassedTyInfo.width) { - auto resultLLVMTy = getTypeConverter()->convertType(resultTy); - auto truncResult = - rewriter.create(loc, resultLLVMTy, result); - - // Extend the truncated result back to the encompassing type to check for - // any overflows during the truncation. - mlir::Value truncResultExt; - if (resultTy.isSigned()) - truncResultExt = rewriter.create( - loc, encompassedLLVMTy, truncResult); - else - truncResultExt = rewriter.create( - loc, encompassedLLVMTy, truncResult); - auto truncOverflow = rewriter.create( - loc, mlir::LLVM::ICmpPredicate::ne, truncResultExt, result); - - result = truncResult; - overflow = - rewriter.create(loc, overflow, truncOverflow); - } - auto boolLLVMTy = - getTypeConverter()->convertType(op.getOverflow().getType()); - if (boolLLVMTy != rewriter.getI1Type()) - overflow = rewriter.create(loc, boolLLVMTy, overflow); - - rewriter.replaceOp(op, mlir::ValueRange{result, overflow}); + if (resultTy.getWidth() < encompassedTyInfo.width) { + auto resultLLVMTy = getTypeConverter()->convertType(resultTy); + auto truncResult = + rewriter.create(loc, resultLLVMTy, result); + + // Extend the truncated result back to the encompassing type to check for + // any overflows during the truncation. + mlir::Value truncResultExt; + if (resultTy.isSigned()) + truncResultExt = rewriter.create( + loc, encompassedLLVMTy, truncResult); + else + truncResultExt = rewriter.create( + loc, encompassedLLVMTy, truncResult); + auto truncOverflow = rewriter.create( + loc, mlir::LLVM::ICmpPredicate::ne, truncResultExt, result); - return mlir::success(); + result = truncResult; + overflow = rewriter.create(loc, overflow, truncOverflow); } -private: - static std::string getLLVMIntrinName(cir::BinOpOverflowKind opKind, - bool isSigned, unsigned width) { - // The intrinsic name is `@llvm.{s|u}{opKind}.with.overflow.i{width}` + auto boolLLVMTy = getTypeConverter()->convertType(op.getOverflow().getType()); + if (boolLLVMTy != rewriter.getI1Type()) + overflow = rewriter.create(loc, boolLLVMTy, overflow); - std::string name = "llvm."; + rewriter.replaceOp(op, mlir::ValueRange{result, overflow}); - if (isSigned) - name.push_back('s'); - else - name.push_back('u'); + return mlir::success(); +} - switch (opKind) { - case cir::BinOpOverflowKind::Add: - name.append("add."); - break; - case cir::BinOpOverflowKind::Sub: - name.append("sub."); - break; - case cir::BinOpOverflowKind::Mul: - name.append("mul."); - break; - } +std::string CIRToLLVMBinOpOverflowOpLowering::getLLVMIntrinName( + cir::BinOpOverflowKind opKind, bool isSigned, unsigned width) { + // The intrinsic name is `@llvm.{s|u}{opKind}.with.overflow.i{width}` - name.append("with.overflow.i"); - name.append(std::to_string(width)); + std::string name = "llvm."; - return name; + if (isSigned) + name.push_back('s'); + else + name.push_back('u'); + + switch (opKind) { + case cir::BinOpOverflowKind::Add: + name.append("add."); + break; + case cir::BinOpOverflowKind::Sub: + name.append("sub."); + break; + case cir::BinOpOverflowKind::Mul: + name.append("mul."); + break; } - struct EncompassedTypeInfo { - bool sign; - unsigned width; - }; + name.append("with.overflow.i"); + name.append(std::to_string(width)); - static EncompassedTypeInfo - computeEncompassedTypeWidth(cir::IntType operandTy, cir::IntType resultTy) { - auto sign = operandTy.getIsSigned() || resultTy.getIsSigned(); - auto width = - std::max(operandTy.getWidth() + (sign && operandTy.isUnsigned()), - resultTy.getWidth() + (sign && resultTy.isUnsigned())); - return {sign, width}; - } -}; + return name; +} -class CIRShiftOpLowering : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::ShiftOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - auto cirAmtTy = mlir::dyn_cast(op.getAmount().getType()); - auto cirValTy = mlir::dyn_cast(op.getValue().getType()); - - // Operands could also be vector type - auto cirAmtVTy = mlir::dyn_cast(op.getAmount().getType()); - auto cirValVTy = mlir::dyn_cast(op.getValue().getType()); - auto llvmTy = getTypeConverter()->convertType(op.getType()); - mlir::Value amt = adaptor.getAmount(); - mlir::Value val = adaptor.getValue(); - - assert(((cirValTy && cirAmtTy) || (cirAmtVTy && cirValVTy)) && - "shift input type must be integer or vector type, otherwise NYI"); - - assert((cirValTy == op.getType() || cirValVTy == op.getType()) && - "inconsistent operands' types NYI"); - - // Ensure shift amount is the same type as the value. Some undefined - // behavior might occur in the casts below as per [C99 6.5.7.3]. - // Vector type shift amount needs no cast as type consistency is expected to - // be already be enforced at CIRGen. - if (cirAmtTy) - amt = getLLVMIntCast(rewriter, amt, mlir::cast(llvmTy), - !cirAmtTy.isSigned(), cirAmtTy.getWidth(), - cirValTy.getWidth()); - - // Lower to the proper LLVM shift operation. - if (op.getIsShiftleft()) - rewriter.replaceOpWithNewOp(op, llvmTy, val, amt); - else { - bool isUnSigned = - cirValTy - ? !cirValTy.isSigned() - : !mlir::cast(cirValVTy.getEltType()).isSigned(); - if (isUnSigned) - rewriter.replaceOpWithNewOp(op, llvmTy, val, amt); - else - rewriter.replaceOpWithNewOp(op, llvmTy, val, amt); - } +CIRToLLVMBinOpOverflowOpLowering::EncompassedTypeInfo +CIRToLLVMBinOpOverflowOpLowering::computeEncompassedTypeWidth( + cir::IntType operandTy, cir::IntType resultTy) { + auto sign = operandTy.getIsSigned() || resultTy.getIsSigned(); + auto width = std::max(operandTy.getWidth() + (sign && operandTy.isUnsigned()), + resultTy.getWidth() + (sign && resultTy.isUnsigned())); + return {sign, width}; +} - return mlir::success(); +mlir::LogicalResult CIRToLLVMShiftOpLowering::matchAndRewrite( + cir::ShiftOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto cirAmtTy = mlir::dyn_cast(op.getAmount().getType()); + auto cirValTy = mlir::dyn_cast(op.getValue().getType()); + + // Operands could also be vector type + auto cirAmtVTy = mlir::dyn_cast(op.getAmount().getType()); + auto cirValVTy = mlir::dyn_cast(op.getValue().getType()); + auto llvmTy = getTypeConverter()->convertType(op.getType()); + mlir::Value amt = adaptor.getAmount(); + mlir::Value val = adaptor.getValue(); + + assert(((cirValTy && cirAmtTy) || (cirAmtVTy && cirValVTy)) && + "shift input type must be integer or vector type, otherwise NYI"); + + assert((cirValTy == op.getType() || cirValVTy == op.getType()) && + "inconsistent operands' types NYI"); + + // Ensure shift amount is the same type as the value. Some undefined + // behavior might occur in the casts below as per [C99 6.5.7.3]. + // Vector type shift amount needs no cast as type consistency is expected to + // be already be enforced at CIRGen. + if (cirAmtTy) + amt = getLLVMIntCast(rewriter, amt, mlir::cast(llvmTy), + !cirAmtTy.isSigned(), cirAmtTy.getWidth(), + cirValTy.getWidth()); + + // Lower to the proper LLVM shift operation. + if (op.getIsShiftleft()) + rewriter.replaceOpWithNewOp(op, llvmTy, val, amt); + else { + bool isUnSigned = + cirValTy ? !cirValTy.isSigned() + : !mlir::cast(cirValVTy.getEltType()).isSigned(); + if (isUnSigned) + rewriter.replaceOpWithNewOp(op, llvmTy, val, amt); + else + rewriter.replaceOpWithNewOp(op, llvmTy, val, amt); } -}; - -class CIRCmpOpLowering : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::CmpOp cmpOp, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - auto type = cmpOp.getLhs().getType(); - mlir::Value llResult; - - // Lower to LLVM comparison op. - if (auto intTy = mlir::dyn_cast(type)) { - auto kind = - convertCmpKindToICmpPredicate(cmpOp.getKind(), intTy.isSigned()); - llResult = rewriter.create( - cmpOp.getLoc(), kind, adaptor.getLhs(), adaptor.getRhs()); - } else if (auto ptrTy = mlir::dyn_cast(type)) { - auto kind = convertCmpKindToICmpPredicate(cmpOp.getKind(), - /* isSigned=*/false); - llResult = rewriter.create( - cmpOp.getLoc(), kind, adaptor.getLhs(), adaptor.getRhs()); - } else if (mlir::isa(type)) { - auto kind = convertCmpKindToFCmpPredicate(cmpOp.getKind()); - llResult = rewriter.create( - cmpOp.getLoc(), kind, adaptor.getLhs(), adaptor.getRhs()); - } else { - return cmpOp.emitError() << "unsupported type for CmpOp: " << type; - } - // LLVM comparison ops return i1, but cir::CmpOp returns the same type as - // the LHS value. Since this return value can be used later, we need to - // restore the type with the extension below. - auto llResultTy = getTypeConverter()->convertType(cmpOp.getType()); - rewriter.replaceOpWithNewOp(cmpOp, llResultTy, - llResult); + return mlir::success(); +} - return mlir::success(); +mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite( + cir::CmpOp cmpOp, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto type = cmpOp.getLhs().getType(); + mlir::Value llResult; + + // Lower to LLVM comparison op. + if (auto intTy = mlir::dyn_cast(type)) { + auto kind = + convertCmpKindToICmpPredicate(cmpOp.getKind(), intTy.isSigned()); + llResult = rewriter.create( + cmpOp.getLoc(), kind, adaptor.getLhs(), adaptor.getRhs()); + } else if (auto ptrTy = mlir::dyn_cast(type)) { + auto kind = convertCmpKindToICmpPredicate(cmpOp.getKind(), + /* isSigned=*/false); + llResult = rewriter.create( + cmpOp.getLoc(), kind, adaptor.getLhs(), adaptor.getRhs()); + } else if (mlir::isa(type)) { + auto kind = convertCmpKindToFCmpPredicate(cmpOp.getKind()); + llResult = rewriter.create( + cmpOp.getLoc(), kind, adaptor.getLhs(), adaptor.getRhs()); + } else { + return cmpOp.emitError() << "unsupported type for CmpOp: " << type; } -}; -static mlir::LLVM::CallIntrinsicOp + // LLVM comparison ops return i1, but cir::CmpOp returns the same type as + // the LHS value. Since this return value can be used later, we need to + // restore the type with the extension below. + auto llResultTy = getTypeConverter()->convertType(cmpOp.getType()); + rewriter.replaceOpWithNewOp(cmpOp, llResultTy, llResult); + + return mlir::success(); +} + +mlir::LLVM::CallIntrinsicOp createCallLLVMIntrinsicOp(mlir::ConversionPatternRewriter &rewriter, mlir::Location loc, const llvm::Twine &intrinsicName, mlir::Type resultTy, mlir::ValueRange operands) { @@ -2987,7 +2686,7 @@ createCallLLVMIntrinsicOp(mlir::ConversionPatternRewriter &rewriter, loc, resultTy, intrinsicNameAttr, operands); } -static mlir::LLVM::CallIntrinsicOp replaceOpWithCallLLVMIntrinsicOp( +mlir::LLVM::CallIntrinsicOp replaceOpWithCallLLVMIntrinsicOp( mlir::ConversionPatternRewriter &rewriter, mlir::Operation *op, const llvm::Twine &intrinsicName, mlir::Type resultTy, mlir::ValueRange operands) { @@ -2997,99 +2696,76 @@ static mlir::LLVM::CallIntrinsicOp replaceOpWithCallLLVMIntrinsicOp( return callIntrinOp; } -class CIRIntrinsicCallLowering - : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::LLVMIntrinsicCallOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - mlir::Type llvmResTy = - getTypeConverter()->convertType(op->getResultTypes()[0]); - if (!llvmResTy) - return op.emitError("expected LLVM result type"); - StringRef name = op.getIntrinsicName(); - // Some llvm intrinsics require ElementType attribute to be attached to - // the argument of pointer type. That prevents us from generating LLVM IR - // because from LLVM dialect, we have LLVM IR like the below which fails - // LLVM IR verification. - // %3 = call i64 @llvm.aarch64.ldxr.p0(ptr %2) - // The expected LLVM IR should be like - // %3 = call i64 @llvm.aarch64.ldxr.p0(ptr elementtype(i32) %2) - // TODO(cir): MLIR LLVM dialect should handle this part as CIR has no way - // to set LLVM IR attribute. - assert(!cir::MissingFeatures::llvmIntrinsicElementTypeSupport()); - replaceOpWithCallLLVMIntrinsicOp(rewriter, op, "llvm." + name, llvmResTy, - adaptor.getOperands()); - return mlir::success(); - } -}; - -class CIRAssumeLowering : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::AssumeOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - auto cond = rewriter.create( - op.getLoc(), rewriter.getI1Type(), adaptor.getPredicate()); - rewriter.replaceOpWithNewOp(op, cond); - return mlir::success(); - } -}; +mlir::LogicalResult CIRToLLVMLLVMIntrinsicCallOpLowering::matchAndRewrite( + cir::LLVMIntrinsicCallOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + mlir::Type llvmResTy = + getTypeConverter()->convertType(op->getResultTypes()[0]); + if (!llvmResTy) + return op.emitError("expected LLVM result type"); + StringRef name = op.getIntrinsicName(); + // Some llvm intrinsics require ElementType attribute to be attached to + // the argument of pointer type. That prevents us from generating LLVM IR + // because from LLVM dialect, we have LLVM IR like the below which fails + // LLVM IR verification. + // %3 = call i64 @llvm.aarch64.ldxr.p0(ptr %2) + // The expected LLVM IR should be like + // %3 = call i64 @llvm.aarch64.ldxr.p0(ptr elementtype(i32) %2) + // TODO(cir): MLIR LLVM dialect should handle this part as CIR has no way + // to set LLVM IR attribute. + assert(!cir::MissingFeatures::llvmIntrinsicElementTypeSupport()); + replaceOpWithCallLLVMIntrinsicOp(rewriter, op, "llvm." + name, llvmResTy, + adaptor.getOperands()); + return mlir::success(); +} -class CIRAssumeAlignedLowering - : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; +mlir::LogicalResult CIRToLLVMAssumeOpLowering::matchAndRewrite( + cir::AssumeOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto cond = rewriter.create( + op.getLoc(), rewriter.getI1Type(), adaptor.getPredicate()); + rewriter.replaceOpWithNewOp(op, cond); + return mlir::success(); +} - mlir::LogicalResult - matchAndRewrite(cir::AssumeAlignedOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - SmallVector opBundleArgs{adaptor.getPointer()}; +mlir::LogicalResult CIRToLLVMAssumeAlignedOpLowering::matchAndRewrite( + cir::AssumeAlignedOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + SmallVector opBundleArgs{adaptor.getPointer()}; - auto alignment = rewriter.create( - op.getLoc(), rewriter.getI64Type(), op.getAlignment()); - opBundleArgs.push_back(alignment); + auto alignment = rewriter.create( + op.getLoc(), rewriter.getI64Type(), op.getAlignment()); + opBundleArgs.push_back(alignment); - if (mlir::Value offset = adaptor.getOffset()) - opBundleArgs.push_back(offset); + if (mlir::Value offset = adaptor.getOffset()) + opBundleArgs.push_back(offset); - auto cond = rewriter.create( - op.getLoc(), rewriter.getI1Type(), 1); - rewriter.create(op.getLoc(), cond, "align", - opBundleArgs); - rewriter.replaceAllUsesWith(op, op.getPointer()); - rewriter.eraseOp(op); + auto cond = rewriter.create(op.getLoc(), + rewriter.getI1Type(), 1); + rewriter.create(op.getLoc(), cond, "align", + opBundleArgs); + rewriter.replaceAllUsesWith(op, op.getPointer()); + rewriter.eraseOp(op); - return mlir::success(); - } -}; + return mlir::success(); +} -class CIRAssumeSepStorageLowering - : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::AssumeSepStorageOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - auto cond = rewriter.create( - op.getLoc(), rewriter.getI1Type(), 1); - rewriter.replaceOpWithNewOp( - op, cond, "separate_storage", - mlir::ValueRange{adaptor.getPtr1(), adaptor.getPtr2()}); - return mlir::success(); - } -}; +mlir::LogicalResult CIRToLLVMAssumeSepStorageOpLowering::matchAndRewrite( + cir::AssumeSepStorageOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto cond = rewriter.create(op.getLoc(), + rewriter.getI1Type(), 1); + rewriter.replaceOpWithNewOp( + op, cond, "separate_storage", + mlir::ValueRange{adaptor.getPtr1(), adaptor.getPtr2()}); + return mlir::success(); +} -static mlir::Value createLLVMBitOp(mlir::Location loc, - const llvm::Twine &llvmIntrinBaseName, - mlir::Type resultTy, mlir::Value operand, - std::optional poisonZeroInputFlag, - mlir::ConversionPatternRewriter &rewriter) { +mlir::Value createLLVMBitOp(mlir::Location loc, + const llvm::Twine &llvmIntrinBaseName, + mlir::Type resultTy, mlir::Value operand, + std::optional poisonZeroInputFlag, + mlir::ConversionPatternRewriter &rewriter) { auto operandIntTy = mlir::cast(operand.getType()); auto resultIntTy = mlir::cast(resultTy); @@ -3117,1275 +2793,1048 @@ static mlir::Value createLLVMBitOp(mlir::Location loc, /*isUnsigned=*/true, operandIntTy.getWidth(), resultIntTy.getWidth()); } -class CIRBitClrsbOpLowering - : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::BitClrsbOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - auto zero = rewriter.create( - op.getLoc(), adaptor.getInput().getType(), 0); - auto isNeg = rewriter.create( - op.getLoc(), - mlir::LLVM::ICmpPredicateAttr::get(rewriter.getContext(), - mlir::LLVM::ICmpPredicate::slt), - adaptor.getInput(), zero); - - auto negOne = rewriter.create( - op.getLoc(), adaptor.getInput().getType(), -1); - auto flipped = rewriter.create( - op.getLoc(), adaptor.getInput(), negOne); - - auto select = rewriter.create( - op.getLoc(), isNeg, flipped, adaptor.getInput()); - - auto resTy = getTypeConverter()->convertType(op.getType()); - auto clz = createLLVMBitOp(op.getLoc(), "llvm.ctlz", resTy, select, - /*poisonZeroInputFlag=*/false, rewriter); - - auto one = rewriter.create(op.getLoc(), resTy, 1); - auto res = rewriter.create(op.getLoc(), clz, one); - rewriter.replaceOp(op, res); - - return mlir::LogicalResult::success(); - } -}; - -class CIRObjSizeOpLowering : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::ObjSizeOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - auto llvmResTy = getTypeConverter()->convertType(op.getType()); - auto loc = op->getLoc(); - - cir::SizeInfoType kindInfo = op.getKind(); - auto falseValue = rewriter.create( - loc, rewriter.getI1Type(), false); - auto trueValue = rewriter.create( - loc, rewriter.getI1Type(), true); - - replaceOpWithCallLLVMIntrinsicOp( - rewriter, op, "llvm.objectsize", llvmResTy, - mlir::ValueRange{adaptor.getPtr(), - kindInfo == cir::SizeInfoType::max ? falseValue - : trueValue, - trueValue, op.getDynamic() ? trueValue : falseValue}); - - return mlir::LogicalResult::success(); - } -}; - -class CIRBitClzOpLowering : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::BitClzOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - auto resTy = getTypeConverter()->convertType(op.getType()); - auto llvmOp = - createLLVMBitOp(op.getLoc(), "llvm.ctlz", resTy, adaptor.getInput(), - /*poisonZeroInputFlag=*/true, rewriter); - rewriter.replaceOp(op, llvmOp); - return mlir::LogicalResult::success(); - } -}; - -class CIRBitCtzOpLowering : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::BitCtzOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - auto resTy = getTypeConverter()->convertType(op.getType()); - auto llvmOp = - createLLVMBitOp(op.getLoc(), "llvm.cttz", resTy, adaptor.getInput(), - /*poisonZeroInputFlag=*/true, rewriter); - rewriter.replaceOp(op, llvmOp); - return mlir::LogicalResult::success(); - } -}; +mlir::LogicalResult CIRToLLVMBitClrsbOpLowering::matchAndRewrite( + cir::BitClrsbOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto zero = rewriter.create( + op.getLoc(), adaptor.getInput().getType(), 0); + auto isNeg = rewriter.create( + op.getLoc(), + mlir::LLVM::ICmpPredicateAttr::get(rewriter.getContext(), + mlir::LLVM::ICmpPredicate::slt), + adaptor.getInput(), zero); + + auto negOne = rewriter.create( + op.getLoc(), adaptor.getInput().getType(), -1); + auto flipped = rewriter.create(op.getLoc(), + adaptor.getInput(), negOne); + + auto select = rewriter.create( + op.getLoc(), isNeg, flipped, adaptor.getInput()); + + auto resTy = getTypeConverter()->convertType(op.getType()); + auto clz = createLLVMBitOp(op.getLoc(), "llvm.ctlz", resTy, select, + /*poisonZeroInputFlag=*/false, rewriter); + + auto one = rewriter.create(op.getLoc(), resTy, 1); + auto res = rewriter.create(op.getLoc(), clz, one); + rewriter.replaceOp(op, res); + + return mlir::LogicalResult::success(); +} -class CIRBitFfsOpLowering : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; +mlir::LogicalResult CIRToLLVMObjSizeOpLowering::matchAndRewrite( + cir::ObjSizeOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto llvmResTy = getTypeConverter()->convertType(op.getType()); + auto loc = op->getLoc(); + + cir::SizeInfoType kindInfo = op.getKind(); + auto falseValue = + rewriter.create(loc, rewriter.getI1Type(), false); + auto trueValue = + rewriter.create(loc, rewriter.getI1Type(), true); + + replaceOpWithCallLLVMIntrinsicOp( + rewriter, op, "llvm.objectsize", llvmResTy, + mlir::ValueRange{adaptor.getPtr(), + kindInfo == cir::SizeInfoType::max ? falseValue + : trueValue, + trueValue, op.getDynamic() ? trueValue : falseValue}); + + return mlir::LogicalResult::success(); +} - mlir::LogicalResult - matchAndRewrite(cir::BitFfsOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - auto resTy = getTypeConverter()->convertType(op.getType()); - auto ctz = - createLLVMBitOp(op.getLoc(), "llvm.cttz", resTy, adaptor.getInput(), - /*poisonZeroInputFlag=*/false, rewriter); +mlir::LogicalResult CIRToLLVMBitClzOpLowering::matchAndRewrite( + cir::BitClzOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto resTy = getTypeConverter()->convertType(op.getType()); + auto llvmOp = + createLLVMBitOp(op.getLoc(), "llvm.ctlz", resTy, adaptor.getInput(), + /*poisonZeroInputFlag=*/true, rewriter); + rewriter.replaceOp(op, llvmOp); + return mlir::LogicalResult::success(); +} - auto one = rewriter.create(op.getLoc(), resTy, 1); - auto ctzAddOne = rewriter.create(op.getLoc(), ctz, one); +mlir::LogicalResult CIRToLLVMBitCtzOpLowering::matchAndRewrite( + cir::BitCtzOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto resTy = getTypeConverter()->convertType(op.getType()); + auto llvmOp = + createLLVMBitOp(op.getLoc(), "llvm.cttz", resTy, adaptor.getInput(), + /*poisonZeroInputFlag=*/true, rewriter); + rewriter.replaceOp(op, llvmOp); + return mlir::LogicalResult::success(); +} - auto zeroInputTy = rewriter.create( - op.getLoc(), adaptor.getInput().getType(), 0); - auto isZero = rewriter.create( - op.getLoc(), - mlir::LLVM::ICmpPredicateAttr::get(rewriter.getContext(), - mlir::LLVM::ICmpPredicate::eq), - adaptor.getInput(), zeroInputTy); +mlir::LogicalResult CIRToLLVMBitFfsOpLowering::matchAndRewrite( + cir::BitFfsOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto resTy = getTypeConverter()->convertType(op.getType()); + auto ctz = + createLLVMBitOp(op.getLoc(), "llvm.cttz", resTy, adaptor.getInput(), + /*poisonZeroInputFlag=*/false, rewriter); + + auto one = rewriter.create(op.getLoc(), resTy, 1); + auto ctzAddOne = rewriter.create(op.getLoc(), ctz, one); + + auto zeroInputTy = rewriter.create( + op.getLoc(), adaptor.getInput().getType(), 0); + auto isZero = rewriter.create( + op.getLoc(), + mlir::LLVM::ICmpPredicateAttr::get(rewriter.getContext(), + mlir::LLVM::ICmpPredicate::eq), + adaptor.getInput(), zeroInputTy); + + auto zero = rewriter.create(op.getLoc(), resTy, 0); + auto res = rewriter.create(op.getLoc(), isZero, zero, + ctzAddOne); + rewriter.replaceOp(op, res); + + return mlir::LogicalResult::success(); +} - auto zero = rewriter.create(op.getLoc(), resTy, 0); - auto res = rewriter.create(op.getLoc(), isZero, zero, - ctzAddOne); - rewriter.replaceOp(op, res); +mlir::LogicalResult CIRToLLVMBitParityOpLowering::matchAndRewrite( + cir::BitParityOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto resTy = getTypeConverter()->convertType(op.getType()); + auto popcnt = + createLLVMBitOp(op.getLoc(), "llvm.ctpop", resTy, adaptor.getInput(), + /*poisonZeroInputFlag=*/std::nullopt, rewriter); - return mlir::LogicalResult::success(); - } -}; + auto one = rewriter.create(op.getLoc(), resTy, 1); + auto popcntMod2 = + rewriter.create(op.getLoc(), popcnt, one); + rewriter.replaceOp(op, popcntMod2); -class CIRBitParityOpLowering - : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::BitParityOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - auto resTy = getTypeConverter()->convertType(op.getType()); - auto popcnt = - createLLVMBitOp(op.getLoc(), "llvm.ctpop", resTy, adaptor.getInput(), - /*poisonZeroInputFlag=*/std::nullopt, rewriter); - - auto one = rewriter.create(op.getLoc(), resTy, 1); - auto popcntMod2 = - rewriter.create(op.getLoc(), popcnt, one); - rewriter.replaceOp(op, popcntMod2); - - return mlir::LogicalResult::success(); - } -}; + return mlir::LogicalResult::success(); +} -class CIRBitPopcountOpLowering - : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::BitPopcountOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - auto resTy = getTypeConverter()->convertType(op.getType()); - auto llvmOp = - createLLVMBitOp(op.getLoc(), "llvm.ctpop", resTy, adaptor.getInput(), - /*poisonZeroInputFlag=*/std::nullopt, rewriter); - rewriter.replaceOp(op, llvmOp); - return mlir::LogicalResult::success(); - } -}; +mlir::LogicalResult CIRToLLVMBitPopcountOpLowering::matchAndRewrite( + cir::BitPopcountOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto resTy = getTypeConverter()->convertType(op.getType()); + auto llvmOp = + createLLVMBitOp(op.getLoc(), "llvm.ctpop", resTy, adaptor.getInput(), + /*poisonZeroInputFlag=*/std::nullopt, rewriter); + rewriter.replaceOp(op, llvmOp); + return mlir::LogicalResult::success(); +} -static mlir::LLVM::AtomicOrdering getLLVMAtomicOrder(cir::MemOrder memo) { +mlir::LLVM::AtomicOrdering getLLVMAtomicOrder(cir::MemOrder memo) { switch (memo) { case cir::MemOrder::Relaxed: return mlir::LLVM::AtomicOrdering::monotonic; - case cir::MemOrder::Consume: - case cir::MemOrder::Acquire: - return mlir::LLVM::AtomicOrdering::acquire; - case cir::MemOrder::Release: - return mlir::LLVM::AtomicOrdering::release; - case cir::MemOrder::AcquireRelease: - return mlir::LLVM::AtomicOrdering::acq_rel; - case cir::MemOrder::SequentiallyConsistent: - return mlir::LLVM::AtomicOrdering::seq_cst; - } - llvm_unreachable("shouldn't get here"); -} - -class CIRAtomicCmpXchgLowering - : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::AtomicCmpXchg op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - auto expected = adaptor.getExpected(); - auto desired = adaptor.getDesired(); - - // FIXME: add syncscope. - auto cmpxchg = rewriter.create( - op.getLoc(), adaptor.getPtr(), expected, desired, - getLLVMAtomicOrder(adaptor.getSuccOrder()), - getLLVMAtomicOrder(adaptor.getFailOrder())); - cmpxchg.setWeak(adaptor.getWeak()); - cmpxchg.setVolatile_(adaptor.getIsVolatile()); - - // Check result and apply stores accordingly. - auto old = rewriter.create( - op.getLoc(), cmpxchg.getResult(), 0); - auto cmp = rewriter.create( - op.getLoc(), cmpxchg.getResult(), 1); - - auto extCmp = rewriter.create( - op.getLoc(), rewriter.getI8Type(), cmp); - rewriter.replaceOp(op, {old, extCmp}); - return mlir::success(); - } -}; - -class CIRAtomicXchgLowering - : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::AtomicXchg op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - // FIXME: add syncscope. - auto llvmOrder = getLLVMAtomicOrder(adaptor.getMemOrder()); - rewriter.replaceOpWithNewOp( - op, mlir::LLVM::AtomicBinOp::xchg, adaptor.getPtr(), adaptor.getVal(), - llvmOrder); - return mlir::success(); - } -}; - -class CIRAtomicFetchLowering - : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::Value buildPostOp(cir::AtomicFetch op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter, - mlir::Value rmwVal, bool isInt) const { - SmallVector atomicOperands = {rmwVal, adaptor.getVal()}; - SmallVector atomicResTys = {rmwVal.getType()}; - return rewriter - .create(op.getLoc(), - rewriter.getStringAttr(getLLVMBinop(op.getBinop(), isInt)), - atomicOperands, atomicResTys, {}) - ->getResult(0); - } - - mlir::Value buildMinMaxPostOp(cir::AtomicFetch op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter, - mlir::Value rmwVal, bool isSigned) const { - auto loc = op.getLoc(); - mlir::LLVM::ICmpPredicate pred; - if (op.getBinop() == cir::AtomicFetchKind::Max) { - pred = isSigned ? mlir::LLVM::ICmpPredicate::sgt - : mlir::LLVM::ICmpPredicate::ugt; - } else { // Min - pred = isSigned ? mlir::LLVM::ICmpPredicate::slt - : mlir::LLVM::ICmpPredicate::ult; - } - - auto cmp = rewriter.create( - loc, mlir::LLVM::ICmpPredicateAttr::get(rewriter.getContext(), pred), - rmwVal, adaptor.getVal()); - return rewriter.create(loc, cmp, rmwVal, - adaptor.getVal()); - } - - llvm::StringLiteral getLLVMBinop(cir::AtomicFetchKind k, bool isInt) const { - switch (k) { - case cir::AtomicFetchKind::Add: - return isInt ? mlir::LLVM::AddOp::getOperationName() - : mlir::LLVM::FAddOp::getOperationName(); - case cir::AtomicFetchKind::Sub: - return isInt ? mlir::LLVM::SubOp::getOperationName() - : mlir::LLVM::FSubOp::getOperationName(); - case cir::AtomicFetchKind::And: - return mlir::LLVM::AndOp::getOperationName(); - case cir::AtomicFetchKind::Xor: - return mlir::LLVM::XOrOp::getOperationName(); - case cir::AtomicFetchKind::Or: - return mlir::LLVM::OrOp::getOperationName(); - case cir::AtomicFetchKind::Nand: - // There's no nand binop in LLVM, this is later fixed with a not. - return mlir::LLVM::AndOp::getOperationName(); - case cir::AtomicFetchKind::Max: - case cir::AtomicFetchKind::Min: - llvm_unreachable("handled in buildMinMaxPostOp"); - } - llvm_unreachable("Unknown atomic fetch opcode"); - } - - mlir::LLVM::AtomicBinOp getLLVMAtomicBinOp(cir::AtomicFetchKind k, bool isInt, - bool isSignedInt) const { - switch (k) { - case cir::AtomicFetchKind::Add: - return isInt ? mlir::LLVM::AtomicBinOp::add - : mlir::LLVM::AtomicBinOp::fadd; - case cir::AtomicFetchKind::Sub: - return isInt ? mlir::LLVM::AtomicBinOp::sub - : mlir::LLVM::AtomicBinOp::fsub; - case cir::AtomicFetchKind::And: - return mlir::LLVM::AtomicBinOp::_and; - case cir::AtomicFetchKind::Xor: - return mlir::LLVM::AtomicBinOp::_xor; - case cir::AtomicFetchKind::Or: - return mlir::LLVM::AtomicBinOp::_or; - case cir::AtomicFetchKind::Nand: - return mlir::LLVM::AtomicBinOp::nand; - case cir::AtomicFetchKind::Max: { - if (!isInt) - return mlir::LLVM::AtomicBinOp::fmax; - return isSignedInt ? mlir::LLVM::AtomicBinOp::max - : mlir::LLVM::AtomicBinOp::umax; - } - case cir::AtomicFetchKind::Min: { - if (!isInt) - return mlir::LLVM::AtomicBinOp::fmin; - return isSignedInt ? mlir::LLVM::AtomicBinOp::min - : mlir::LLVM::AtomicBinOp::umin; - } - } - llvm_unreachable("Unknown atomic fetch opcode"); - } - - mlir::LogicalResult - matchAndRewrite(cir::AtomicFetch op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - - bool isInt, isSignedInt = false; // otherwise it's float. - if (auto intTy = mlir::dyn_cast(op.getVal().getType())) { - isInt = true; - isSignedInt = intTy.isSigned(); - } else if (mlir::isa( - op.getVal().getType())) - isInt = false; - else { - return op.emitError() - << "Unsupported type: " << adaptor.getVal().getType(); - } - - // FIXME: add syncscope. - auto llvmOrder = getLLVMAtomicOrder(adaptor.getMemOrder()); - auto llvmBinOpc = getLLVMAtomicBinOp(op.getBinop(), isInt, isSignedInt); - auto rmwVal = rewriter.create( - op.getLoc(), llvmBinOpc, adaptor.getPtr(), adaptor.getVal(), llvmOrder); - - mlir::Value result = rmwVal.getRes(); - if (!op.getFetchFirst()) { - if (op.getBinop() == cir::AtomicFetchKind::Max || - op.getBinop() == cir::AtomicFetchKind::Min) - result = buildMinMaxPostOp(op, adaptor, rewriter, rmwVal.getRes(), - isSignedInt); - else - result = buildPostOp(op, adaptor, rewriter, rmwVal.getRes(), isInt); - - // Compensate lack of nand binop in LLVM IR. - if (op.getBinop() == cir::AtomicFetchKind::Nand) { - auto negOne = rewriter.create( - op.getLoc(), result.getType(), -1); - result = - rewriter.create(op.getLoc(), result, negOne); - } - } - - rewriter.replaceOp(op, result); - return mlir::success(); + case cir::MemOrder::Consume: + case cir::MemOrder::Acquire: + return mlir::LLVM::AtomicOrdering::acquire; + case cir::MemOrder::Release: + return mlir::LLVM::AtomicOrdering::release; + case cir::MemOrder::AcquireRelease: + return mlir::LLVM::AtomicOrdering::acq_rel; + case cir::MemOrder::SequentiallyConsistent: + return mlir::LLVM::AtomicOrdering::seq_cst; } -}; + llvm_unreachable("shouldn't get here"); +} -class CIRByteswapOpLowering - : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; +mlir::LogicalResult CIRToLLVMAtomicCmpXchgLowering::matchAndRewrite( + cir::AtomicCmpXchg op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto expected = adaptor.getExpected(); + auto desired = adaptor.getDesired(); + + // FIXME: add syncscope. + auto cmpxchg = rewriter.create( + op.getLoc(), adaptor.getPtr(), expected, desired, + getLLVMAtomicOrder(adaptor.getSuccOrder()), + getLLVMAtomicOrder(adaptor.getFailOrder())); + cmpxchg.setWeak(adaptor.getWeak()); + cmpxchg.setVolatile_(adaptor.getIsVolatile()); + + // Check result and apply stores accordingly. + auto old = rewriter.create( + op.getLoc(), cmpxchg.getResult(), 0); + auto cmp = rewriter.create( + op.getLoc(), cmpxchg.getResult(), 1); + + auto extCmp = rewriter.create(op.getLoc(), + rewriter.getI8Type(), cmp); + rewriter.replaceOp(op, {old, extCmp}); + return mlir::success(); +} - mlir::LogicalResult - matchAndRewrite(cir::ByteswapOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - // Note that LLVM intrinsic calls to @llvm.bswap.i* have the same type as - // the operand. +mlir::LogicalResult CIRToLLVMAtomicXchgLowering::matchAndRewrite( + cir::AtomicXchg op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + // FIXME: add syncscope. + auto llvmOrder = getLLVMAtomicOrder(adaptor.getMemOrder()); + rewriter.replaceOpWithNewOp( + op, mlir::LLVM::AtomicBinOp::xchg, adaptor.getPtr(), adaptor.getVal(), + llvmOrder); + return mlir::success(); +} - auto resTy = mlir::cast( - getTypeConverter()->convertType(op.getType())); +mlir::Value CIRToLLVMAtomicFetchLowering::buildPostOp( + cir::AtomicFetch op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter, mlir::Value rmwVal, + bool isInt) const { + SmallVector atomicOperands = {rmwVal, adaptor.getVal()}; + SmallVector atomicResTys = {rmwVal.getType()}; + return rewriter + .create(op.getLoc(), + rewriter.getStringAttr(getLLVMBinop(op.getBinop(), isInt)), + atomicOperands, atomicResTys, {}) + ->getResult(0); +} - std::string llvmIntrinName = "llvm.bswap.i"; - llvmIntrinName.append(std::to_string(resTy.getWidth())); +mlir::Value CIRToLLVMAtomicFetchLowering::buildMinMaxPostOp( + cir::AtomicFetch op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter, mlir::Value rmwVal, + bool isSigned) const { + auto loc = op.getLoc(); + mlir::LLVM::ICmpPredicate pred; + if (op.getBinop() == cir::AtomicFetchKind::Max) { + pred = isSigned ? mlir::LLVM::ICmpPredicate::sgt + : mlir::LLVM::ICmpPredicate::ugt; + } else { // Min + pred = isSigned ? mlir::LLVM::ICmpPredicate::slt + : mlir::LLVM::ICmpPredicate::ult; + } + + auto cmp = rewriter.create( + loc, mlir::LLVM::ICmpPredicateAttr::get(rewriter.getContext(), pred), + rmwVal, adaptor.getVal()); + return rewriter.create(loc, cmp, rmwVal, + adaptor.getVal()); +} - rewriter.replaceOpWithNewOp(op, adaptor.getInput()); +llvm::StringLiteral +CIRToLLVMAtomicFetchLowering::getLLVMBinop(cir::AtomicFetchKind k, + bool isInt) const { + switch (k) { + case cir::AtomicFetchKind::Add: + return isInt ? mlir::LLVM::AddOp::getOperationName() + : mlir::LLVM::FAddOp::getOperationName(); + case cir::AtomicFetchKind::Sub: + return isInt ? mlir::LLVM::SubOp::getOperationName() + : mlir::LLVM::FSubOp::getOperationName(); + case cir::AtomicFetchKind::And: + return mlir::LLVM::AndOp::getOperationName(); + case cir::AtomicFetchKind::Xor: + return mlir::LLVM::XOrOp::getOperationName(); + case cir::AtomicFetchKind::Or: + return mlir::LLVM::OrOp::getOperationName(); + case cir::AtomicFetchKind::Nand: + // There's no nand binop in LLVM, this is later fixed with a not. + return mlir::LLVM::AndOp::getOperationName(); + case cir::AtomicFetchKind::Max: + case cir::AtomicFetchKind::Min: + llvm_unreachable("handled in buildMinMaxPostOp"); + } + llvm_unreachable("Unknown atomic fetch opcode"); +} - return mlir::LogicalResult::success(); - } -}; +mlir::LLVM::AtomicBinOp CIRToLLVMAtomicFetchLowering::getLLVMAtomicBinOp( + cir::AtomicFetchKind k, bool isInt, bool isSignedInt) const { + switch (k) { + case cir::AtomicFetchKind::Add: + return isInt ? mlir::LLVM::AtomicBinOp::add : mlir::LLVM::AtomicBinOp::fadd; + case cir::AtomicFetchKind::Sub: + return isInt ? mlir::LLVM::AtomicBinOp::sub : mlir::LLVM::AtomicBinOp::fsub; + case cir::AtomicFetchKind::And: + return mlir::LLVM::AtomicBinOp::_and; + case cir::AtomicFetchKind::Xor: + return mlir::LLVM::AtomicBinOp::_xor; + case cir::AtomicFetchKind::Or: + return mlir::LLVM::AtomicBinOp::_or; + case cir::AtomicFetchKind::Nand: + return mlir::LLVM::AtomicBinOp::nand; + case cir::AtomicFetchKind::Max: { + if (!isInt) + return mlir::LLVM::AtomicBinOp::fmax; + return isSignedInt ? mlir::LLVM::AtomicBinOp::max + : mlir::LLVM::AtomicBinOp::umax; + } + case cir::AtomicFetchKind::Min: { + if (!isInt) + return mlir::LLVM::AtomicBinOp::fmin; + return isSignedInt ? mlir::LLVM::AtomicBinOp::min + : mlir::LLVM::AtomicBinOp::umin; + } + } + llvm_unreachable("Unknown atomic fetch opcode"); +} -class CIRRotateOpLowering : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::RotateOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - // Note that LLVM intrinsic calls to @llvm.fsh{r,l}.i* have the same type as - // the operand. - auto src = adaptor.getSrc(); - if (op.getLeft()) - rewriter.replaceOpWithNewOp(op, src, src, - adaptor.getAmt()); +mlir::LogicalResult CIRToLLVMAtomicFetchLowering::matchAndRewrite( + cir::AtomicFetch op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + + bool isInt, isSignedInt = false; // otherwise it's float. + if (auto intTy = mlir::dyn_cast(op.getVal().getType())) { + isInt = true; + isSignedInt = intTy.isSigned(); + } else if (mlir::isa(op.getVal().getType())) + isInt = false; + else { + return op.emitError() << "Unsupported type: " << adaptor.getVal().getType(); + } + + // FIXME: add syncscope. + auto llvmOrder = getLLVMAtomicOrder(adaptor.getMemOrder()); + auto llvmBinOpc = getLLVMAtomicBinOp(op.getBinop(), isInt, isSignedInt); + auto rmwVal = rewriter.create( + op.getLoc(), llvmBinOpc, adaptor.getPtr(), adaptor.getVal(), llvmOrder); + + mlir::Value result = rmwVal.getRes(); + if (!op.getFetchFirst()) { + if (op.getBinop() == cir::AtomicFetchKind::Max || + op.getBinop() == cir::AtomicFetchKind::Min) + result = buildMinMaxPostOp(op, adaptor, rewriter, rmwVal.getRes(), + isSignedInt); else - rewriter.replaceOpWithNewOp(op, src, src, - adaptor.getAmt()); - return mlir::LogicalResult::success(); - } -}; + result = buildPostOp(op, adaptor, rewriter, rmwVal.getRes(), isInt); -class CIRSelectOpLowering : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; + // Compensate lack of nand binop in LLVM IR. + if (op.getBinop() == cir::AtomicFetchKind::Nand) { + auto negOne = rewriter.create( + op.getLoc(), result.getType(), -1); + result = rewriter.create(op.getLoc(), result, negOne); + } + } - mlir::LogicalResult - matchAndRewrite(cir::SelectOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - auto getConstantBool = [](mlir::Value value) -> std::optional { - auto definingOp = - mlir::dyn_cast_if_present(value.getDefiningOp()); - if (!definingOp) - return std::nullopt; + rewriter.replaceOp(op, result); + return mlir::success(); +} - auto constValue = mlir::dyn_cast(definingOp.getValue()); - if (!constValue) - return std::nullopt; +mlir::LogicalResult CIRToLLVMByteswapOpLowering::matchAndRewrite( + cir::ByteswapOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + // Note that LLVM intrinsic calls to @llvm.bswap.i* have the same type as + // the operand. - return constValue.getValue(); - }; + auto resTy = mlir::cast( + getTypeConverter()->convertType(op.getType())); - // Two special cases in the LLVMIR codegen of select op: - // - select %0, %1, false => and %0, %1 - // - select %0, true, %1 => or %0, %1 - auto trueValue = op.getTrueValue(); - auto falseValue = op.getFalseValue(); - if (mlir::isa(trueValue.getType())) { - if (std::optional falseValueBool = getConstantBool(falseValue); - falseValueBool.has_value() && !*falseValueBool) { - // select %0, %1, false => and %0, %1 - rewriter.replaceOpWithNewOp( - op, adaptor.getCondition(), adaptor.getTrueValue()); - return mlir::success(); - } - if (std::optional trueValueBool = getConstantBool(trueValue); - trueValueBool.has_value() && *trueValueBool) { - // select %0, true, %1 => or %0, %1 - rewriter.replaceOpWithNewOp( - op, adaptor.getCondition(), adaptor.getFalseValue()); - return mlir::success(); - } - } + std::string llvmIntrinName = "llvm.bswap.i"; + llvmIntrinName.append(std::to_string(resTy.getWidth())); - auto llvmCondition = rewriter.create( - op.getLoc(), mlir::IntegerType::get(op->getContext(), 1), - adaptor.getCondition()); - rewriter.replaceOpWithNewOp( - op, llvmCondition, adaptor.getTrueValue(), adaptor.getFalseValue()); + rewriter.replaceOpWithNewOp(op, adaptor.getInput()); - return mlir::success(); - } -}; + return mlir::LogicalResult::success(); +} -class CIRBrOpLowering : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; +mlir::LogicalResult CIRToLLVMRotateOpLowering::matchAndRewrite( + cir::RotateOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + // Note that LLVM intrinsic calls to @llvm.fsh{r,l}.i* have the same type as + // the operand. + auto src = adaptor.getSrc(); + if (op.getLeft()) + rewriter.replaceOpWithNewOp(op, src, src, + adaptor.getAmt()); + else + rewriter.replaceOpWithNewOp(op, src, src, + adaptor.getAmt()); + return mlir::LogicalResult::success(); +} - mlir::LogicalResult - matchAndRewrite(cir::BrOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, adaptor.getOperands(), - op.getDest()); - return mlir::LogicalResult::success(); - } -}; +mlir::LogicalResult CIRToLLVMSelectOpLowering::matchAndRewrite( + cir::SelectOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto getConstantBool = [](mlir::Value value) -> std::optional { + auto definingOp = + mlir::dyn_cast_if_present(value.getDefiningOp()); + if (!definingOp) + return std::nullopt; -class CIRGetMemberOpLowering - : public mlir::OpConversionPattern { -public: - using mlir::OpConversionPattern::OpConversionPattern; + auto constValue = mlir::dyn_cast(definingOp.getValue()); + if (!constValue) + return std::nullopt; - mlir::LogicalResult - matchAndRewrite(cir::GetMemberOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - auto llResTy = getTypeConverter()->convertType(op.getType()); - const auto structTy = - mlir::cast(op.getAddrTy().getPointee()); - assert(structTy && "expected struct type"); + return constValue.getValue(); + }; - switch (structTy.getKind()) { - case cir::StructType::Struct: - case cir::StructType::Class: { - // Since the base address is a pointer to an aggregate, the first offset - // is always zero. The second offset tell us which member it will access. - llvm::SmallVector offset{0, op.getIndex()}; - const auto elementTy = getTypeConverter()->convertType(structTy); - rewriter.replaceOpWithNewOp(op, llResTy, elementTy, - adaptor.getAddr(), offset); + // Two special cases in the LLVMIR codegen of select op: + // - select %0, %1, false => and %0, %1 + // - select %0, true, %1 => or %0, %1 + auto trueValue = op.getTrueValue(); + auto falseValue = op.getFalseValue(); + if (mlir::isa(trueValue.getType())) { + if (std::optional falseValueBool = getConstantBool(falseValue); + falseValueBool.has_value() && !*falseValueBool) { + // select %0, %1, false => and %0, %1 + rewriter.replaceOpWithNewOp(op, adaptor.getCondition(), + adaptor.getTrueValue()); return mlir::success(); } - case cir::StructType::Union: - // Union members share the address space, so we just need a bitcast to - // conform to type-checking. - rewriter.replaceOpWithNewOp(op, llResTy, - adaptor.getAddr()); + if (std::optional trueValueBool = getConstantBool(trueValue); + trueValueBool.has_value() && *trueValueBool) { + // select %0, true, %1 => or %0, %1 + rewriter.replaceOpWithNewOp(op, adaptor.getCondition(), + adaptor.getFalseValue()); return mlir::success(); } } -}; - -class CIRGetRuntimeMemberOpLowering - : public mlir::OpConversionPattern { -public: - using mlir::OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::GetRuntimeMemberOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - auto llvmResTy = getTypeConverter()->convertType(op.getType()); - auto llvmElementTy = mlir::IntegerType::get(op.getContext(), 8); - - rewriter.replaceOpWithNewOp( - op, llvmResTy, llvmElementTy, adaptor.getAddr(), adaptor.getMember()); - return mlir::success(); - } -}; - -class CIRPtrDiffOpLowering : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - uint64_t getTypeSize(mlir::Type type, mlir::Operation &op) const { - mlir::DataLayout layout(op.getParentOfType()); - // For LLVM purposes we treat void as u8. - if (isa(type)) - type = cir::IntType::get(type.getContext(), 8, /*isSigned=*/false); - return llvm::divideCeil(layout.getTypeSizeInBits(type), 8); - } - - mlir::LogicalResult - matchAndRewrite(cir::PtrDiffOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - auto dstTy = mlir::cast(op.getType()); - auto llvmDstTy = getTypeConverter()->convertType(dstTy); - auto lhs = rewriter.create(op.getLoc(), llvmDstTy, - adaptor.getLhs()); - auto rhs = rewriter.create(op.getLoc(), llvmDstTy, - adaptor.getRhs()); + auto llvmCondition = rewriter.create( + op.getLoc(), mlir::IntegerType::get(op->getContext(), 1), + adaptor.getCondition()); + rewriter.replaceOpWithNewOp( + op, llvmCondition, adaptor.getTrueValue(), adaptor.getFalseValue()); - auto diff = - rewriter.create(op.getLoc(), llvmDstTy, lhs, rhs); - - auto ptrTy = mlir::cast(op.getLhs().getType()); - auto typeSize = getTypeSize(ptrTy.getPointee(), *op); + return mlir::success(); +} - // Avoid silly division by 1. - auto resultVal = diff.getResult(); - if (typeSize != 1) { - auto typeSizeVal = rewriter.create( - op.getLoc(), llvmDstTy, mlir::IntegerAttr::get(llvmDstTy, typeSize)); +mlir::LogicalResult CIRToLLVMBrOpLowering::matchAndRewrite( + cir::BrOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + rewriter.replaceOpWithNewOp(op, adaptor.getOperands(), + op.getDest()); + return mlir::LogicalResult::success(); +} - if (dstTy.isUnsigned()) - resultVal = rewriter.create(op.getLoc(), llvmDstTy, - diff, typeSizeVal); - else - resultVal = rewriter.create(op.getLoc(), llvmDstTy, - diff, typeSizeVal); - } - rewriter.replaceOp(op, resultVal); +mlir::LogicalResult CIRToLLVMGetMemberOpLowering::matchAndRewrite( + cir::GetMemberOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto llResTy = getTypeConverter()->convertType(op.getType()); + const auto structTy = + mlir::cast(op.getAddrTy().getPointee()); + assert(structTy && "expected struct type"); + + switch (structTy.getKind()) { + case cir::StructType::Struct: + case cir::StructType::Class: { + // Since the base address is a pointer to an aggregate, the first offset + // is always zero. The second offset tell us which member it will access. + llvm::SmallVector offset{0, op.getIndex()}; + const auto elementTy = getTypeConverter()->convertType(structTy); + rewriter.replaceOpWithNewOp(op, llResTy, elementTy, + adaptor.getAddr(), offset); return mlir::success(); } -}; - -class CIRExpectOpLowering : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::ExpectOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - std::optional prob = op.getProb(); - if (!prob) - rewriter.replaceOpWithNewOp(op, adaptor.getVal(), - adaptor.getExpected()); - else - rewriter.replaceOpWithNewOp( - op, adaptor.getVal(), adaptor.getExpected(), prob.value()); + case cir::StructType::Union: + // Union members share the address space, so we just need a bitcast to + // conform to type-checking. + rewriter.replaceOpWithNewOp(op, llResTy, + adaptor.getAddr()); return mlir::success(); } -}; - -class CIRVTableAddrPointOpLowering - : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::VTableAddrPointOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - const auto *converter = getTypeConverter(); - auto targetType = converter->convertType(op.getType()); - mlir::Value symAddr = op.getSymAddr(); - llvm::SmallVector offsets; - mlir::Type eltType; - if (!symAddr) { - symAddr = getValueForVTableSymbol(op, rewriter, getTypeConverter(), - op.getNameAttr(), eltType); - offsets = llvm::SmallVector{ - 0, op.getVtableIndex(), op.getAddressPointIndex()}; - } else { - // Get indirect vtable address point retrieval - symAddr = adaptor.getSymAddr(); - eltType = converter->convertType(symAddr.getType()); - offsets = - llvm::SmallVector{op.getAddressPointIndex()}; - } - - assert(eltType && "Shouldn't ever be missing an eltType here"); - rewriter.replaceOpWithNewOp(op, targetType, eltType, - symAddr, offsets, true); +} - return mlir::success(); - } -}; +mlir::LogicalResult CIRToLLVMGetRuntimeMemberOpLowering::matchAndRewrite( + cir::GetRuntimeMemberOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto llvmResTy = getTypeConverter()->convertType(op.getType()); + auto llvmElementTy = mlir::IntegerType::get(op.getContext(), 8); -class CIRStackSaveLowering - : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; + rewriter.replaceOpWithNewOp( + op, llvmResTy, llvmElementTy, adaptor.getAddr(), adaptor.getMember()); + return mlir::success(); +} - mlir::LogicalResult - matchAndRewrite(cir::StackSaveOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - auto ptrTy = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp(op, ptrTy); - return mlir::success(); - } -}; +uint64_t CIRToLLVMPtrDiffOpLowering::getTypeSize(mlir::Type type, + mlir::Operation &op) const { + mlir::DataLayout layout(op.getParentOfType()); + // For LLVM purposes we treat void as u8. + if (isa(type)) + type = cir::IntType::get(type.getContext(), 8, /*isSigned=*/false); + return llvm::divideCeil(layout.getTypeSizeInBits(type), 8); +} -#define GET_BUILTIN_LOWERING_CLASSES -#include "clang/CIR/Dialect/IR/CIRBuiltinsLowering.inc" +mlir::LogicalResult CIRToLLVMPtrDiffOpLowering::matchAndRewrite( + cir::PtrDiffOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto dstTy = mlir::cast(op.getType()); + auto llvmDstTy = getTypeConverter()->convertType(dstTy); -class CIRUnreachableLowering - : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; + auto lhs = rewriter.create(op.getLoc(), llvmDstTy, + adaptor.getLhs()); + auto rhs = rewriter.create(op.getLoc(), llvmDstTy, + adaptor.getRhs()); - mlir::LogicalResult - matchAndRewrite(cir::UnreachableOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op); - return mlir::success(); - } -}; + auto diff = + rewriter.create(op.getLoc(), llvmDstTy, lhs, rhs); -class CIRTrapLowering : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; + auto ptrTy = mlir::cast(op.getLhs().getType()); + auto typeSize = getTypeSize(ptrTy.getPointee(), *op); - mlir::LogicalResult - matchAndRewrite(cir::TrapOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - auto loc = op->getLoc(); - rewriter.eraseOp(op); + // Avoid silly division by 1. + auto resultVal = diff.getResult(); + if (typeSize != 1) { + auto typeSizeVal = rewriter.create( + op.getLoc(), llvmDstTy, mlir::IntegerAttr::get(llvmDstTy, typeSize)); - rewriter.create(loc); + if (dstTy.isUnsigned()) + resultVal = rewriter.create(op.getLoc(), llvmDstTy, + diff, typeSizeVal); + else + resultVal = rewriter.create(op.getLoc(), llvmDstTy, + diff, typeSizeVal); + } + rewriter.replaceOp(op, resultVal); + return mlir::success(); +} - // Note that the call to llvm.trap is not a terminator in LLVM dialect. - // So we must emit an additional llvm.unreachable to terminate the current - // block. - rewriter.create(loc); +mlir::LogicalResult CIRToLLVMExpectOpLowering::matchAndRewrite( + cir::ExpectOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + std::optional prob = op.getProb(); + if (!prob) + rewriter.replaceOpWithNewOp(op, adaptor.getVal(), + adaptor.getExpected()); + else + rewriter.replaceOpWithNewOp( + op, adaptor.getVal(), adaptor.getExpected(), prob.value()); + return mlir::success(); +} - return mlir::success(); +mlir::LogicalResult CIRToLLVMVTableAddrPointOpLowering::matchAndRewrite( + cir::VTableAddrPointOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + const auto *converter = getTypeConverter(); + auto targetType = converter->convertType(op.getType()); + mlir::Value symAddr = op.getSymAddr(); + llvm::SmallVector offsets; + mlir::Type eltType; + if (!symAddr) { + symAddr = getValueForVTableSymbol(op, rewriter, getTypeConverter(), + op.getNameAttr(), eltType); + offsets = llvm::SmallVector{0, op.getVtableIndex(), + op.getAddressPointIndex()}; + } else { + // Get indirect vtable address point retrieval + symAddr = adaptor.getSymAddr(); + eltType = converter->convertType(symAddr.getType()); + offsets = llvm::SmallVector{op.getAddressPointIndex()}; } -}; -class CIRInlineAsmOpLowering - : public mlir::OpConversionPattern { + assert(eltType && "Shouldn't ever be missing an eltType here"); + rewriter.replaceOpWithNewOp(op, targetType, eltType, + symAddr, offsets, true); - using mlir::OpConversionPattern::OpConversionPattern; + return mlir::success(); +} - mlir::LogicalResult - matchAndRewrite(cir::InlineAsmOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - mlir::Type llResTy; - if (op.getNumResults()) - llResTy = getTypeConverter()->convertType(op.getType(0)); +mlir::LogicalResult CIRToLLVMStackSaveOpLowering::matchAndRewrite( + cir::StackSaveOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto ptrTy = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, ptrTy); + return mlir::success(); +} - auto dialect = op.getAsmFlavor(); - auto llDialect = dialect == cir::AsmFlavor::x86_att - ? mlir::LLVM::AsmDialect::AD_ATT - : mlir::LLVM::AsmDialect::AD_Intel; +#define GET_BUILTIN_LOWERING_CLASSES_DEF +#include "clang/CIR/Dialect/IR/CIRBuiltinsLowering.inc" +#undef GET_BUILTIN_LOWERING_CLASSES_DEF - std::vector opAttrs; - auto llvmAttrName = mlir::LLVM::InlineAsmOp::getElementTypeAttrName(); +mlir::LogicalResult CIRToLLVMUnreachableOpLowering::matchAndRewrite( + cir::UnreachableOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + rewriter.replaceOpWithNewOp(op); + return mlir::success(); +} - // this is for the lowering to LLVM from LLVm dialect. Otherwise, if we - // don't have the result (i.e. void type as a result of operation), the - // element type attribute will be attached to the whole instruction, but not - // to the operand - if (!op.getNumResults()) - opAttrs.push_back(mlir::Attribute()); +mlir::LogicalResult CIRToLLVMTrapOpLowering::matchAndRewrite( + cir::TrapOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto loc = op->getLoc(); + rewriter.eraseOp(op); - llvm::SmallVector llvmOperands; - llvm::SmallVector cirOperands; - for (size_t i = 0; i < op.getOperands().size(); ++i) { - auto llvmOps = adaptor.getOperands()[i]; - auto cirOps = op.getOperands()[i]; - llvmOperands.insert(llvmOperands.end(), llvmOps.begin(), llvmOps.end()); - cirOperands.insert(cirOperands.end(), cirOps.begin(), cirOps.end()); - } + rewriter.create(loc); - // so far we infer the llvm dialect element type attr from - // CIR operand type. - for (std::size_t i = 0; i < op.getOperandAttrs().size(); ++i) { - if (!op.getOperandAttrs()[i]) { - opAttrs.push_back(mlir::Attribute()); - continue; - } + // Note that the call to llvm.trap is not a terminator in LLVM dialect. + // So we must emit an additional llvm.unreachable to terminate the current + // block. + rewriter.create(loc); - std::vector attrs; - auto typ = cast(cirOperands[i].getType()); - auto typAttr = mlir::TypeAttr::get( - getTypeConverter()->convertType(typ.getPointee())); + return mlir::success(); +} - attrs.push_back(rewriter.getNamedAttr(llvmAttrName, typAttr)); - auto newDict = rewriter.getDictionaryAttr(attrs); - opAttrs.push_back(newDict); +mlir::LogicalResult CIRToLLVMInlineAsmOpLowering::matchAndRewrite( + cir::InlineAsmOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + mlir::Type llResTy; + if (op.getNumResults()) + llResTy = getTypeConverter()->convertType(op.getType(0)); + + auto dialect = op.getAsmFlavor(); + auto llDialect = dialect == cir::AsmFlavor::x86_att + ? mlir::LLVM::AsmDialect::AD_ATT + : mlir::LLVM::AsmDialect::AD_Intel; + + std::vector opAttrs; + auto llvmAttrName = mlir::LLVM::InlineAsmOp::getElementTypeAttrName(); + + // this is for the lowering to LLVM from LLVm dialect. Otherwise, if we + // don't have the result (i.e. void type as a result of operation), the + // element type attribute will be attached to the whole instruction, but not + // to the operand + if (!op.getNumResults()) + opAttrs.push_back(mlir::Attribute()); + + llvm::SmallVector llvmOperands; + llvm::SmallVector cirOperands; + for (size_t i = 0; i < op.getOperands().size(); ++i) { + auto llvmOps = adaptor.getOperands()[i]; + auto cirOps = op.getOperands()[i]; + llvmOperands.insert(llvmOperands.end(), llvmOps.begin(), llvmOps.end()); + cirOperands.insert(cirOperands.end(), cirOps.begin(), cirOps.end()); + } + + // so far we infer the llvm dialect element type attr from + // CIR operand type. + for (std::size_t i = 0; i < op.getOperandAttrs().size(); ++i) { + if (!op.getOperandAttrs()[i]) { + opAttrs.push_back(mlir::Attribute()); + continue; } - rewriter.replaceOpWithNewOp( - op, llResTy, llvmOperands, op.getAsmStringAttr(), - op.getConstraintsAttr(), op.getSideEffectsAttr(), - /*is_align_stack*/ mlir::UnitAttr(), - mlir::LLVM::AsmDialectAttr::get(getContext(), llDialect), - rewriter.getArrayAttr(opAttrs)); + std::vector attrs; + auto typ = cast(cirOperands[i].getType()); + auto typAttr = + mlir::TypeAttr::get(getTypeConverter()->convertType(typ.getPointee())); - return mlir::success(); + attrs.push_back(rewriter.getNamedAttr(llvmAttrName, typAttr)); + auto newDict = rewriter.getDictionaryAttr(attrs); + opAttrs.push_back(newDict); } -}; -class CIRPrefetchLowering : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; + rewriter.replaceOpWithNewOp( + op, llResTy, llvmOperands, op.getAsmStringAttr(), op.getConstraintsAttr(), + op.getSideEffectsAttr(), + /*is_align_stack*/ mlir::UnitAttr(), + mlir::LLVM::AsmDialectAttr::get(getContext(), llDialect), + rewriter.getArrayAttr(opAttrs)); - mlir::LogicalResult - matchAndRewrite(cir::PrefetchOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp( - op, adaptor.getAddr(), adaptor.getIsWrite(), adaptor.getLocality(), - /*DataCache*/ 1); - return mlir::success(); - } -}; + return mlir::success(); +} -class CIRSetBitfieldLowering - : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; +mlir::LogicalResult CIRToLLVMPrefetchOpLowering::matchAndRewrite( + cir::PrefetchOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + rewriter.replaceOpWithNewOp( + op, adaptor.getAddr(), adaptor.getIsWrite(), adaptor.getLocality(), + /*DataCache*/ 1); + return mlir::success(); +} - mlir::LogicalResult - matchAndRewrite(cir::SetBitfieldOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - mlir::OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(op); +mlir::LogicalResult CIRToLLVMSetBitfieldOpLowering::matchAndRewrite( + cir::SetBitfieldOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + mlir::OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(op); - auto info = op.getBitfieldInfo(); - auto size = info.getSize(); - auto offset = info.getOffset(); - auto storageType = info.getStorageType(); - auto context = storageType.getContext(); + auto info = op.getBitfieldInfo(); + auto size = info.getSize(); + auto offset = info.getOffset(); + auto storageType = info.getStorageType(); + auto context = storageType.getContext(); - unsigned storageSize = 0; + unsigned storageSize = 0; - if (auto arTy = mlir::dyn_cast(storageType)) - storageSize = arTy.getSize() * 8; - else if (auto intTy = mlir::dyn_cast(storageType)) - storageSize = intTy.getWidth(); - else - llvm_unreachable( - "Either ArrayType or IntType expected for bitfields storage"); + if (auto arTy = mlir::dyn_cast(storageType)) + storageSize = arTy.getSize() * 8; + else if (auto intTy = mlir::dyn_cast(storageType)) + storageSize = intTy.getWidth(); + else + llvm_unreachable( + "Either ArrayType or IntType expected for bitfields storage"); - auto intType = mlir::IntegerType::get(context, storageSize); - auto srcVal = createIntCast(rewriter, adaptor.getSrc(), intType); - auto srcWidth = storageSize; - auto resultVal = srcVal; + auto intType = mlir::IntegerType::get(context, storageSize); + auto srcVal = createIntCast(rewriter, adaptor.getSrc(), intType); + auto srcWidth = storageSize; + auto resultVal = srcVal; - if (storageSize != size) { - assert(storageSize > size && "Invalid bitfield size."); + if (storageSize != size) { + assert(storageSize > size && "Invalid bitfield size."); - mlir::Value val = rewriter.create( - op.getLoc(), intType, adaptor.getAddr(), /* alignment */ 0, - op.getIsVolatile()); + mlir::Value val = rewriter.create( + op.getLoc(), intType, adaptor.getAddr(), /* alignment */ 0, + op.getIsVolatile()); - srcVal = createAnd(rewriter, srcVal, - llvm::APInt::getLowBitsSet(srcWidth, size)); - resultVal = srcVal; - srcVal = createShL(rewriter, srcVal, offset); + srcVal = + createAnd(rewriter, srcVal, llvm::APInt::getLowBitsSet(srcWidth, size)); + resultVal = srcVal; + srcVal = createShL(rewriter, srcVal, offset); - // Mask out the original value. - val = - createAnd(rewriter, val, + // Mask out the original value. + val = createAnd(rewriter, val, ~llvm::APInt::getBitsSet(srcWidth, offset, offset + size)); - // Or together the unchanged values and the source value. - srcVal = rewriter.create(op.getLoc(), val, srcVal); - } - - rewriter.create(op.getLoc(), srcVal, adaptor.getAddr(), - /* alignment */ 0, op.getIsVolatile()); - - auto resultTy = getTypeConverter()->convertType(op.getType()); - - resultVal = createIntCast(rewriter, resultVal, - mlir::cast(resultTy)); - - if (info.getIsSigned()) { - assert(size <= storageSize); - unsigned highBits = storageSize - size; - - if (highBits) { - resultVal = createShL(rewriter, resultVal, highBits); - resultVal = createAShR(rewriter, resultVal, highBits); - } - } - - rewriter.replaceOp(op, resultVal); - return mlir::success(); + // Or together the unchanged values and the source value. + srcVal = rewriter.create(op.getLoc(), val, srcVal); } -}; - -class CIRGetBitfieldLowering - : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - mlir::LogicalResult - matchAndRewrite(cir::GetBitfieldOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { + rewriter.create(op.getLoc(), srcVal, adaptor.getAddr(), + /* alignment */ 0, op.getIsVolatile()); - mlir::OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(op); - - auto info = op.getBitfieldInfo(); - auto size = info.getSize(); - auto offset = info.getOffset(); - auto storageType = info.getStorageType(); - auto context = storageType.getContext(); - unsigned storageSize = 0; - - if (auto arTy = mlir::dyn_cast(storageType)) - storageSize = arTy.getSize() * 8; - else if (auto intTy = mlir::dyn_cast(storageType)) - storageSize = intTy.getWidth(); - else - llvm_unreachable( - "Either ArrayType or IntType expected for bitfields storage"); + auto resultTy = getTypeConverter()->convertType(op.getType()); - auto intType = mlir::IntegerType::get(context, storageSize); + resultVal = createIntCast(rewriter, resultVal, + mlir::cast(resultTy)); - mlir::Value val = rewriter.create( - op.getLoc(), intType, adaptor.getAddr(), 0, op.getIsVolatile()); - val = rewriter.create(op.getLoc(), intType, val); - - if (info.getIsSigned()) { - assert(static_cast(offset + size) <= storageSize); - unsigned highBits = storageSize - offset - size; - val = createShL(rewriter, val, highBits); - val = createAShR(rewriter, val, offset + highBits); - } else { - val = createLShR(rewriter, val, offset); + if (info.getIsSigned()) { + assert(size <= storageSize); + unsigned highBits = storageSize - size; - if (static_cast(offset) + size < storageSize) - val = createAnd(rewriter, val, - llvm::APInt::getLowBitsSet(storageSize, size)); + if (highBits) { + resultVal = createShL(rewriter, resultVal, highBits); + resultVal = createAShR(rewriter, resultVal, highBits); } - - auto resTy = getTypeConverter()->convertType(op.getType()); - auto newOp = - createIntCast(rewriter, val, mlir::cast(resTy), - info.getIsSigned()); - rewriter.replaceOp(op, newOp); - return mlir::success(); - } -}; - -class CIRIsConstantOpLowering - : public mlir::OpConversionPattern { - - using mlir::OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::IsConstantOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - // FIXME(cir): llvm.intr.is.constant returns i1 value but the LLVM Lowering - // expects that cir.bool type will be lowered as i8 type. - // So we have to insert zext here. - auto isConstantOP = rewriter.create( - op.getLoc(), adaptor.getVal()); - rewriter.replaceOpWithNewOp(op, rewriter.getI8Type(), - isConstantOP); - return mlir::success(); } -}; -class CIRCmpThreeWayOpLowering - : public mlir::OpConversionPattern { -public: - using mlir::OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::CmpThreeWayOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - if (!op.isIntegralComparison() || !op.isStrongOrdering()) { - op.emitError() << "unsupported three-way comparison type"; - return mlir::failure(); - } + rewriter.replaceOp(op, resultVal); + return mlir::success(); +} - auto cmpInfo = op.getInfo(); - assert(cmpInfo.getLt() == -1 && cmpInfo.getEq() == 0 && - cmpInfo.getGt() == 1); +mlir::LogicalResult CIRToLLVMGetBitfieldOpLowering::matchAndRewrite( + cir::GetBitfieldOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + + mlir::OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(op); + + auto info = op.getBitfieldInfo(); + auto size = info.getSize(); + auto offset = info.getOffset(); + auto storageType = info.getStorageType(); + auto context = storageType.getContext(); + unsigned storageSize = 0; + + if (auto arTy = mlir::dyn_cast(storageType)) + storageSize = arTy.getSize() * 8; + else if (auto intTy = mlir::dyn_cast(storageType)) + storageSize = intTy.getWidth(); + else + llvm_unreachable( + "Either ArrayType or IntType expected for bitfields storage"); + + auto intType = mlir::IntegerType::get(context, storageSize); + + mlir::Value val = rewriter.create( + op.getLoc(), intType, adaptor.getAddr(), 0, op.getIsVolatile()); + val = rewriter.create(op.getLoc(), intType, val); + + if (info.getIsSigned()) { + assert(static_cast(offset + size) <= storageSize); + unsigned highBits = storageSize - offset - size; + val = createShL(rewriter, val, highBits); + val = createAShR(rewriter, val, offset + highBits); + } else { + val = createLShR(rewriter, val, offset); - auto operandTy = mlir::cast(op.getLhs().getType()); - auto resultTy = op.getType(); - auto llvmIntrinsicName = getLLVMIntrinsicName( - operandTy.isSigned(), operandTy.getWidth(), resultTy.getWidth()); + if (static_cast(offset) + size < storageSize) + val = createAnd(rewriter, val, + llvm::APInt::getLowBitsSet(storageSize, size)); + } - rewriter.setInsertionPoint(op); + auto resTy = getTypeConverter()->convertType(op.getType()); + auto newOp = createIntCast( + rewriter, val, mlir::cast(resTy), info.getIsSigned()); + rewriter.replaceOp(op, newOp); + return mlir::success(); +} - auto llvmLhs = adaptor.getLhs(); - auto llvmRhs = adaptor.getRhs(); - auto llvmResultTy = getTypeConverter()->convertType(resultTy); - auto callIntrinsicOp = - createCallLLVMIntrinsicOp(rewriter, op.getLoc(), llvmIntrinsicName, - llvmResultTy, {llvmLhs, llvmRhs}); +mlir::LogicalResult CIRToLLVMIsConstantOpLowering::matchAndRewrite( + cir::IsConstantOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + // FIXME(cir): llvm.intr.is.constant returns i1 value but the LLVM Lowering + // expects that cir.bool type will be lowered as i8 type. + // So we have to insert zext here. + auto isConstantOP = + rewriter.create(op.getLoc(), adaptor.getVal()); + rewriter.replaceOpWithNewOp(op, rewriter.getI8Type(), + isConstantOP); + return mlir::success(); +} - rewriter.replaceOp(op, callIntrinsicOp); - return mlir::success(); +mlir::LogicalResult CIRToLLVMCmpThreeWayOpLowering::matchAndRewrite( + cir::CmpThreeWayOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + if (!op.isIntegralComparison() || !op.isStrongOrdering()) { + op.emitError() << "unsupported three-way comparison type"; + return mlir::failure(); } -private: - static std::string getLLVMIntrinsicName(bool signedCmp, unsigned operandWidth, - unsigned resultWidth) { - // The intrinsic's name takes the form: - // `llvm..i.i` + auto cmpInfo = op.getInfo(); + assert(cmpInfo.getLt() == -1 && cmpInfo.getEq() == 0 && cmpInfo.getGt() == 1); - std::string result = "llvm."; + auto operandTy = mlir::cast(op.getLhs().getType()); + auto resultTy = op.getType(); + auto llvmIntrinsicName = getLLVMIntrinsicName( + operandTy.isSigned(), operandTy.getWidth(), resultTy.getWidth()); - if (signedCmp) - result.append("scmp."); - else - result.append("ucmp."); + rewriter.setInsertionPoint(op); - // Result type part. - result.push_back('i'); - result.append(std::to_string(resultWidth)); - result.push_back('.'); + auto llvmLhs = adaptor.getLhs(); + auto llvmRhs = adaptor.getRhs(); + auto llvmResultTy = getTypeConverter()->convertType(resultTy); + auto callIntrinsicOp = + createCallLLVMIntrinsicOp(rewriter, op.getLoc(), llvmIntrinsicName, + llvmResultTy, {llvmLhs, llvmRhs}); - // Operand type part. - result.push_back('i'); - result.append(std::to_string(operandWidth)); + rewriter.replaceOp(op, callIntrinsicOp); + return mlir::success(); +} - return result; - } -}; +std::string CIRToLLVMCmpThreeWayOpLowering::getLLVMIntrinsicName( + bool signedCmp, unsigned operandWidth, unsigned resultWidth) { + // The intrinsic's name takes the form: + // `llvm..i.i` -class CIRReturnAddrOpLowering - : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; + std::string result = "llvm."; - mlir::LogicalResult - matchAndRewrite(cir::ReturnAddrOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext()); - replaceOpWithCallLLVMIntrinsicOp(rewriter, op, "llvm.returnaddress", - llvmPtrTy, adaptor.getOperands()); - return mlir::success(); - } -}; + if (signedCmp) + result.append("scmp."); + else + result.append("ucmp."); -class CIRClearCacheOpLowering - : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::ClearCacheOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - auto begin = adaptor.getBegin(); - auto end = adaptor.getEnd(); - auto intrinNameAttr = - mlir::StringAttr::get(op.getContext(), "llvm.clear_cache"); - rewriter.replaceOpWithNewOp( - op, mlir::Type{}, intrinNameAttr, mlir::ValueRange{begin, end}); + // Result type part. + result.push_back('i'); + result.append(std::to_string(resultWidth)); + result.push_back('.'); - return mlir::success(); - } -}; + // Operand type part. + result.push_back('i'); + result.append(std::to_string(operandWidth)); -class CIREhTypeIdOpLowering - : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; + return result; +} - mlir::LogicalResult - matchAndRewrite(cir::EhTypeIdOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - mlir::Value addrOp = rewriter.create( - op.getLoc(), mlir::LLVM::LLVMPointerType::get(rewriter.getContext()), - op.getTypeSymAttr()); - mlir::LLVM::CallIntrinsicOp newOp = createCallLLVMIntrinsicOp( - rewriter, op.getLoc(), "llvm.eh.typeid.for.p0", rewriter.getI32Type(), - mlir::ValueRange{addrOp}); - rewriter.replaceOp(op, newOp); - return mlir::success(); - } -}; +mlir::LogicalResult CIRToLLVMReturnAddrOpLowering::matchAndRewrite( + cir::ReturnAddrOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext()); + replaceOpWithCallLLVMIntrinsicOp(rewriter, op, "llvm.returnaddress", + llvmPtrTy, adaptor.getOperands()); + return mlir::success(); +} -class CIRCatchParamOpLowering - : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::CatchParamOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - if (op.isBegin()) { - // Get or create `declare ptr @__cxa_begin_catch(ptr)` - StringRef fnName = "__cxa_begin_catch"; - auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext()); - auto fnTy = mlir::LLVM::LLVMFunctionType::get(llvmPtrTy, {llvmPtrTy}, - /*isVarArg=*/false); - getOrCreateLLVMFuncOp(rewriter, op, fnName, fnTy); - rewriter.replaceOpWithNewOp( - op, mlir::TypeRange{llvmPtrTy}, fnName, - mlir::ValueRange{adaptor.getExceptionPtr()}); - return mlir::success(); - } else if (op.isEnd()) { - StringRef fnName = "__cxa_end_catch"; - auto fnTy = mlir::LLVM::LLVMFunctionType::get( - mlir::LLVM::LLVMVoidType::get(rewriter.getContext()), {}, - /*isVarArg=*/false); - getOrCreateLLVMFuncOp(rewriter, op, fnName, fnTy); - rewriter.create(op.getLoc(), mlir::TypeRange{}, - fnName, mlir::ValueRange{}); - rewriter.eraseOp(op); - return mlir::success(); - } - llvm_unreachable("only begin/end supposed to make to lowering stage"); - return mlir::failure(); - } -}; +mlir::LogicalResult CIRToLLVMClearCacheOpLowering::matchAndRewrite( + cir::ClearCacheOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto begin = adaptor.getBegin(); + auto end = adaptor.getEnd(); + auto intrinNameAttr = + mlir::StringAttr::get(op.getContext(), "llvm.clear_cache"); + rewriter.replaceOpWithNewOp( + op, mlir::Type{}, intrinNameAttr, mlir::ValueRange{begin, end}); -class CIRResumeOpLowering : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::ResumeOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - // %lpad.val = insertvalue { ptr, i32 } poison, ptr %exception_ptr, 0 - // %lpad.val2 = insertvalue { ptr, i32 } %lpad.val, i32 %selector, 1 - // resume { ptr, i32 } %lpad.val2 - SmallVector slotIdx = {0}; - SmallVector selectorIdx = {1}; - auto llvmLandingPadStructTy = getLLVMLandingPadStructTy(rewriter); - mlir::Value poison = rewriter.create( - op.getLoc(), llvmLandingPadStructTy); - - mlir::Value slot = rewriter.create( - op.getLoc(), poison, adaptor.getExceptionPtr(), slotIdx); - mlir::Value selector = rewriter.create( - op.getLoc(), slot, adaptor.getTypeId(), selectorIdx); - - rewriter.replaceOpWithNewOp(op, selector); - return mlir::success(); - } -}; + return mlir::success(); +} -class CIRAllocExceptionOpLowering - : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; +mlir::LogicalResult CIRToLLVMEhTypeIdOpLowering::matchAndRewrite( + cir::EhTypeIdOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + mlir::Value addrOp = rewriter.create( + op.getLoc(), mlir::LLVM::LLVMPointerType::get(rewriter.getContext()), + op.getTypeSymAttr()); + mlir::LLVM::CallIntrinsicOp newOp = createCallLLVMIntrinsicOp( + rewriter, op.getLoc(), "llvm.eh.typeid.for.p0", rewriter.getI32Type(), + mlir::ValueRange{addrOp}); + rewriter.replaceOp(op, newOp); + return mlir::success(); +} - mlir::LogicalResult - matchAndRewrite(cir::AllocExceptionOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - // Get or create `declare ptr @__cxa_allocate_exception(i64)` - StringRef fnName = "__cxa_allocate_exception"; +mlir::LogicalResult CIRToLLVMCatchParamOpLowering::matchAndRewrite( + cir::CatchParamOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + if (op.isBegin()) { + // Get or create `declare ptr @__cxa_begin_catch(ptr)` + StringRef fnName = "__cxa_begin_catch"; auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext()); - auto int64Ty = mlir::IntegerType::get(rewriter.getContext(), 64); - auto fnTy = mlir::LLVM::LLVMFunctionType::get(llvmPtrTy, {int64Ty}, + auto fnTy = mlir::LLVM::LLVMFunctionType::get(llvmPtrTy, {llvmPtrTy}, /*isVarArg=*/false); getOrCreateLLVMFuncOp(rewriter, op, fnName, fnTy); - auto size = rewriter.create(op.getLoc(), - adaptor.getSizeAttr()); rewriter.replaceOpWithNewOp( - op, mlir::TypeRange{llvmPtrTy}, fnName, mlir::ValueRange{size}); + op, mlir::TypeRange{llvmPtrTy}, fnName, + mlir::ValueRange{adaptor.getExceptionPtr()}); return mlir::success(); - } -}; - -class CIRFreeExceptionOpLowering - : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::FreeExceptionOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - // Get or create `declare void @__cxa_free_exception(ptr)` - StringRef fnName = "__cxa_free_exception"; - auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext()); - auto voidTy = mlir::LLVM::LLVMVoidType::get(rewriter.getContext()); - auto fnTy = mlir::LLVM::LLVMFunctionType::get(voidTy, {llvmPtrTy}, - /*isVarArg=*/false); + } else if (op.isEnd()) { + StringRef fnName = "__cxa_end_catch"; + auto fnTy = mlir::LLVM::LLVMFunctionType::get( + mlir::LLVM::LLVMVoidType::get(rewriter.getContext()), {}, + /*isVarArg=*/false); getOrCreateLLVMFuncOp(rewriter, op, fnName, fnTy); - rewriter.replaceOpWithNewOp( - op, mlir::TypeRange{}, fnName, mlir::ValueRange{adaptor.getPtr()}); + rewriter.create(op.getLoc(), mlir::TypeRange{}, fnName, + mlir::ValueRange{}); + rewriter.eraseOp(op); return mlir::success(); } -}; + llvm_unreachable("only begin/end supposed to make to lowering stage"); + return mlir::failure(); +} -class CIRThrowOpLowering : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; +mlir::LogicalResult CIRToLLVMResumeOpLowering::matchAndRewrite( + cir::ResumeOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + // %lpad.val = insertvalue { ptr, i32 } poison, ptr %exception_ptr, 0 + // %lpad.val2 = insertvalue { ptr, i32 } %lpad.val, i32 %selector, 1 + // resume { ptr, i32 } %lpad.val2 + SmallVector slotIdx = {0}; + SmallVector selectorIdx = {1}; + auto llvmLandingPadStructTy = getLLVMLandingPadStructTy(rewriter); + mlir::Value poison = rewriter.create( + op.getLoc(), llvmLandingPadStructTy); + + mlir::Value slot = rewriter.create( + op.getLoc(), poison, adaptor.getExceptionPtr(), slotIdx); + mlir::Value selector = rewriter.create( + op.getLoc(), slot, adaptor.getTypeId(), selectorIdx); + + rewriter.replaceOpWithNewOp(op, selector); + return mlir::success(); +} - mlir::LogicalResult - matchAndRewrite(cir::ThrowOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - // Get or create `declare void @__cxa_throw(ptr, ptr, ptr)` - StringRef fnName = "__cxa_throw"; - auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext()); - auto voidTy = mlir::LLVM::LLVMVoidType::get(rewriter.getContext()); - auto fnTy = mlir::LLVM::LLVMFunctionType::get( - voidTy, {llvmPtrTy, llvmPtrTy, llvmPtrTy}, - /*isVarArg=*/false); - getOrCreateLLVMFuncOp(rewriter, op, fnName, fnTy); - mlir::Value typeInfo = rewriter.create( - op.getLoc(), mlir::LLVM::LLVMPointerType::get(rewriter.getContext()), - adaptor.getTypeInfoAttr()); +mlir::LogicalResult CIRToLLVMAllocExceptionOpLowering::matchAndRewrite( + cir::AllocExceptionOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + // Get or create `declare ptr @__cxa_allocate_exception(i64)` + StringRef fnName = "__cxa_allocate_exception"; + auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext()); + auto int64Ty = mlir::IntegerType::get(rewriter.getContext(), 64); + auto fnTy = mlir::LLVM::LLVMFunctionType::get(llvmPtrTy, {int64Ty}, + /*isVarArg=*/false); + getOrCreateLLVMFuncOp(rewriter, op, fnName, fnTy); + auto size = rewriter.create(op.getLoc(), + adaptor.getSizeAttr()); + rewriter.replaceOpWithNewOp( + op, mlir::TypeRange{llvmPtrTy}, fnName, mlir::ValueRange{size}); + return mlir::success(); +} - mlir::Value dtor; - if (op.getDtor()) { - dtor = rewriter.create( - op.getLoc(), mlir::LLVM::LLVMPointerType::get(rewriter.getContext()), - adaptor.getDtorAttr()); - } else { - dtor = rewriter.create( - op.getLoc(), mlir::LLVM::LLVMPointerType::get(rewriter.getContext())); - } - rewriter.replaceOpWithNewOp( - op, mlir::TypeRange{}, fnName, - mlir::ValueRange{adaptor.getExceptionPtr(), typeInfo, dtor}); - return mlir::success(); - } -}; +mlir::LogicalResult CIRToLLVMFreeExceptionOpLowering::matchAndRewrite( + cir::FreeExceptionOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + // Get or create `declare void @__cxa_free_exception(ptr)` + StringRef fnName = "__cxa_free_exception"; + auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext()); + auto voidTy = mlir::LLVM::LLVMVoidType::get(rewriter.getContext()); + auto fnTy = mlir::LLVM::LLVMFunctionType::get(voidTy, {llvmPtrTy}, + /*isVarArg=*/false); + getOrCreateLLVMFuncOp(rewriter, op, fnName, fnTy); + rewriter.replaceOpWithNewOp( + op, mlir::TypeRange{}, fnName, mlir::ValueRange{adaptor.getPtr()}); + return mlir::success(); +} -class CIRIsFPClassOpLowering - : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::IsFPClassOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - auto src = adaptor.getSrc(); - auto flags = adaptor.getFlags(); - auto retTy = rewriter.getI1Type(); - - auto loc = op->getLoc(); - - auto intrinsic = - rewriter.create(loc, retTy, src, flags); - // FIMXE: CIR now will convert cir::BoolType to i8 type unconditionally. - // Remove this conversion after we fix - // https://github.com/llvm/clangir/issues/480 - auto converted = rewriter.create( - loc, rewriter.getI8Type(), intrinsic->getResult(0)); - - rewriter.replaceOp(op, converted); - return mlir::success(); +mlir::LogicalResult CIRToLLVMThrowOpLowering::matchAndRewrite( + cir::ThrowOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + // Get or create `declare void @__cxa_throw(ptr, ptr, ptr)` + StringRef fnName = "__cxa_throw"; + auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext()); + auto voidTy = mlir::LLVM::LLVMVoidType::get(rewriter.getContext()); + auto fnTy = mlir::LLVM::LLVMFunctionType::get( + voidTy, {llvmPtrTy, llvmPtrTy, llvmPtrTy}, + /*isVarArg=*/false); + getOrCreateLLVMFuncOp(rewriter, op, fnName, fnTy); + mlir::Value typeInfo = rewriter.create( + op.getLoc(), mlir::LLVM::LLVMPointerType::get(rewriter.getContext()), + adaptor.getTypeInfoAttr()); + + mlir::Value dtor; + if (op.getDtor()) { + dtor = rewriter.create( + op.getLoc(), mlir::LLVM::LLVMPointerType::get(rewriter.getContext()), + adaptor.getDtorAttr()); + } else { + dtor = rewriter.create( + op.getLoc(), mlir::LLVM::LLVMPointerType::get(rewriter.getContext())); } -}; + rewriter.replaceOpWithNewOp( + op, mlir::TypeRange{}, fnName, + mlir::ValueRange{adaptor.getExceptionPtr(), typeInfo, dtor}); + return mlir::success(); +} -class CIRPtrMaskOpLowering : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::PtrMaskOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - // FIXME: We'd better to lower to mlir::LLVM::PtrMaskOp if it exists. - // So we have to make it manually here by following: - // https://llvm.org/docs/LangRef.html#llvm-ptrmask-intrinsic - auto loc = op.getLoc(); - auto mask = op.getMask(); - - auto moduleOp = op->getParentOfType(); - mlir::DataLayout layout(moduleOp); - auto iPtrIdxValue = layout.getTypeSizeInBits(mask.getType()); - auto iPtrIdx = mlir::IntegerType::get(moduleOp->getContext(), iPtrIdxValue); - - auto intPtr = rewriter.create( - loc, iPtrIdx, adaptor.getPtr()); // this may truncate - mlir::Value masked = - rewriter.create(loc, intPtr, adaptor.getMask()); - mlir::Value diff = rewriter.create(loc, intPtr, masked); - rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), - mlir::IntegerType::get(moduleOp->getContext(), 8), adaptor.getPtr(), - diff); - return mlir::success(); - } -}; +mlir::LogicalResult CIRToLLVMIsFPClassOpLowering::matchAndRewrite( + cir::IsFPClassOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto src = adaptor.getSrc(); + auto flags = adaptor.getFlags(); + auto retTy = rewriter.getI1Type(); -class CIRAbsOpLowering : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(cir::AbsOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - auto resTy = this->getTypeConverter()->convertType(op.getType()); - auto absOp = rewriter.create( - op.getLoc(), resTy, adaptor.getOperands()[0], adaptor.getPoison()); - rewriter.replaceOp(op, absOp); - return mlir::success(); - } -}; -class CIRSignBitOpLowering : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; + auto loc = op->getLoc(); - mlir::LogicalResult - matchAndRewrite(cir::SignBitOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - assert(!::cir::MissingFeatures::isPPC_FP128Ty()); + auto intrinsic = + rewriter.create(loc, retTy, src, flags); + // FIMXE: CIR now will convert cir::BoolType to i8 type unconditionally. + // Remove this conversion after we fix + // https://github.com/llvm/clangir/issues/480 + auto converted = rewriter.create( + loc, rewriter.getI8Type(), intrinsic->getResult(0)); - mlir::DataLayout layout(op->getParentOfType()); - int width = layout.getTypeSizeInBits(op.getInput().getType()); - if (auto longDoubleType = - mlir::dyn_cast(op.getInput().getType())) { - if (mlir::isa(longDoubleType.getUnderlying())) { - // If the underlying type of LongDouble is FP80Type, - // DataLayout::getTypeSizeInBits returns 128. - // See https://github.com/llvm/clangir/issues/1057. - // Set the width to 80 manually. - width = 80; - } + rewriter.replaceOp(op, converted); + return mlir::success(); +} + +mlir::LogicalResult CIRToLLVMAbsOpLowering::matchAndRewrite( + cir::AbsOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto resTy = this->getTypeConverter()->convertType(op.getType()); + auto absOp = rewriter.create( + op.getLoc(), resTy, adaptor.getOperands()[0], adaptor.getPoison()); + rewriter.replaceOp(op, absOp); + return mlir::success(); +} + +mlir::LogicalResult CIRToLLVMPtrMaskOpLowering::matchAndRewrite( + cir::PtrMaskOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + // FIXME: We'd better to lower to mlir::LLVM::PtrMaskOp if it exists. + // So we have to make it manually here by following: + // https://llvm.org/docs/LangRef.html#llvm-ptrmask-intrinsic + auto loc = op.getLoc(); + auto mask = op.getMask(); + + auto moduleOp = op->getParentOfType(); + mlir::DataLayout layout(moduleOp); + auto iPtrIdxValue = layout.getTypeSizeInBits(mask.getType()); + auto iPtrIdx = mlir::IntegerType::get(moduleOp->getContext(), iPtrIdxValue); + + auto intPtr = rewriter.create( + loc, iPtrIdx, adaptor.getPtr()); // this may truncate + mlir::Value masked = + rewriter.create(loc, intPtr, adaptor.getMask()); + mlir::Value diff = rewriter.create(loc, intPtr, masked); + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), + mlir::IntegerType::get(moduleOp->getContext(), 8), adaptor.getPtr(), + diff); + return mlir::success(); +} + +mlir::LogicalResult CIRToLLVMSignBitOpLowering::matchAndRewrite( + cir::SignBitOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + assert(!::cir::MissingFeatures::isPPC_FP128Ty()); + + mlir::DataLayout layout(op->getParentOfType()); + int width = layout.getTypeSizeInBits(op.getInput().getType()); + if (auto longDoubleType = + mlir::dyn_cast(op.getInput().getType())) { + if (mlir::isa(longDoubleType.getUnderlying())) { + // If the underlying type of LongDouble is FP80Type, + // DataLayout::getTypeSizeInBits returns 128. + // See https://github.com/llvm/clangir/issues/1057. + // Set the width to 80 manually. + width = 80; } - auto intTy = mlir::IntegerType::get(rewriter.getContext(), width); - auto bitcast = rewriter.create(op->getLoc(), intTy, - adaptor.getInput()); - auto zero = rewriter.create(op->getLoc(), intTy, 0); - auto cmpResult = rewriter.create( - op.getLoc(), mlir::LLVM::ICmpPredicate::slt, bitcast.getResult(), zero); - auto converted = rewriter.create( - op.getLoc(), mlir::IntegerType::get(rewriter.getContext(), 32), - cmpResult); - rewriter.replaceOp(op, converted); - return mlir::success(); } -}; + auto intTy = mlir::IntegerType::get(rewriter.getContext(), width); + auto bitcast = rewriter.create(op->getLoc(), intTy, + adaptor.getInput()); + auto zero = rewriter.create(op->getLoc(), intTy, 0); + auto cmpResult = rewriter.create( + op.getLoc(), mlir::LLVM::ICmpPredicate::slt, bitcast.getResult(), zero); + auto converted = rewriter.create( + op.getLoc(), mlir::IntegerType::get(rewriter.getContext(), 32), + cmpResult); + rewriter.replaceOp(op, converted); + return mlir::success(); +} void populateCIRToLLVMConversionPatterns( mlir::RewritePatternSet &patterns, mlir::TypeConverter &converter, @@ -4393,57 +3842,62 @@ void populateCIRToLLVMConversionPatterns( llvm::StringMap &stringGlobalsMap, llvm::StringMap &argStringGlobalsMap, llvm::MapVector &argsVarMap) { - patterns.add(patterns.getContext()); - patterns.add(converter, dataLayout, stringGlobalsMap, - argStringGlobalsMap, argsVarMap, - patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(converter, dataLayout, + stringGlobalsMap, argStringGlobalsMap, + argsVarMap, patterns.getContext()); patterns.add< - CIRCmpOpLowering, CIRSelectOpLowering, CIRBitClrsbOpLowering, - CIRBitClzOpLowering, CIRBitCtzOpLowering, CIRBitFfsOpLowering, - CIRBitParityOpLowering, CIRBitPopcountOpLowering, - CIRAtomicCmpXchgLowering, CIRAtomicXchgLowering, CIRAtomicFetchLowering, - CIRByteswapOpLowering, CIRRotateOpLowering, CIRBrCondOpLowering, - CIRPtrStrideOpLowering, CIRCallLowering, CIRTryCallLowering, - CIREhInflightOpLowering, CIRUnaryOpLowering, CIRBinOpLowering, - CIRBinOpOverflowOpLowering, CIRShiftOpLowering, CIRLoadLowering, - CIRConstantLowering, CIRStoreLowering, CIRFuncLowering, CIRCastOpLowering, - CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRComplexCreateOpLowering, - CIRComplexRealOpLowering, CIRComplexImagOpLowering, - CIRComplexRealPtrOpLowering, CIRComplexImagPtrOpLowering, - CIRVAStartLowering, CIRVAEndLowering, CIRVACopyLowering, CIRVAArgLowering, - CIRBrOpLowering, CIRGetMemberOpLowering, CIRGetRuntimeMemberOpLowering, - CIRSwitchFlatOpLowering, CIRPtrDiffOpLowering, CIRCopyOpLowering, - CIRMemCpyOpLowering, CIRMemChrOpLowering, CIRMemCpyInlineOpLowering, - CIRFAbsOpLowering, CIRExpectOpLowering, CIRVTableAddrPointOpLowering, - CIRVectorCreateLowering, CIRVectorCmpOpLowering, CIRVectorSplatLowering, - CIRVectorTernaryLowering, CIRVectorShuffleIntsLowering, - CIRVectorShuffleVecLowering, CIRStackSaveLowering, CIRUnreachableLowering, - CIRTrapLowering, CIRInlineAsmOpLowering, CIRSetBitfieldLowering, - CIRGetBitfieldLowering, CIRPrefetchLowering, CIRObjSizeOpLowering, - CIRIsConstantOpLowering, CIRCmpThreeWayOpLowering, CIRMemCpyOpLowering, - CIRFAbsOpLowering, CIRExpectOpLowering, CIRVTableAddrPointOpLowering, - CIRVectorCreateLowering, CIRVectorCmpOpLowering, CIRVectorSplatLowering, - CIRVectorTernaryLowering, CIRVectorShuffleIntsLowering, - CIRVectorShuffleVecLowering, CIRStackSaveLowering, CIRUnreachableLowering, - CIRTrapLowering, CIRInlineAsmOpLowering, CIRSetBitfieldLowering, - CIRGetBitfieldLowering, CIRPrefetchLowering, CIRObjSizeOpLowering, - CIRIsConstantOpLowering, CIRCmpThreeWayOpLowering, - CIRReturnAddrOpLowering, CIRClearCacheOpLowering, CIREhTypeIdOpLowering, - CIRCatchParamOpLowering, CIRResumeOpLowering, CIRAllocExceptionOpLowering, - CIRFreeExceptionOpLowering, CIRThrowOpLowering, CIRIntrinsicCallLowering, - CIRAssumeLowering, CIRAssumeAlignedLowering, CIRAssumeSepStorageLowering, - CIRBaseClassAddrOpLowering, CIRDerivedClassAddrOpLowering, - CIRVTTAddrPointOpLowering, CIRIsFPClassOpLowering, CIRAbsOpLowering, - CIRMemMoveOpLowering, CIRMemsetOpLowering, CIRSignBitOpLowering, - CIRPtrMaskOpLowering + CIRToLLVMCmpOpLowering, CIRToLLVMSelectOpLowering, + CIRToLLVMBitClrsbOpLowering, CIRToLLVMBitClzOpLowering, + CIRToLLVMBitCtzOpLowering, CIRToLLVMBitFfsOpLowering, + CIRToLLVMBitParityOpLowering, CIRToLLVMBitPopcountOpLowering, + CIRToLLVMAtomicCmpXchgLowering, CIRToLLVMAtomicXchgLowering, + CIRToLLVMAtomicFetchLowering, CIRToLLVMByteswapOpLowering, + CIRToLLVMRotateOpLowering, CIRToLLVMBrCondOpLowering, + CIRToLLVMPtrStrideOpLowering, CIRToLLVMCallOpLowering, + CIRToLLVMTryCallOpLowering, CIRToLLVMEhInflightOpLowering, + CIRToLLVMUnaryOpLowering, CIRToLLVMBinOpLowering, + CIRToLLVMBinOpOverflowOpLowering, CIRToLLVMShiftOpLowering, + CIRToLLVMLoadOpLowering, CIRToLLVMConstantOpLowering, + CIRToLLVMStoreOpLowering, CIRToLLVMFuncOpLowering, + CIRToLLVMCastOpLowering, CIRToLLVMGlobalOpLowering, + CIRToLLVMGetGlobalOpLowering, CIRToLLVMComplexCreateOpLowering, + CIRToLLVMComplexRealOpLowering, CIRToLLVMComplexImagOpLowering, + CIRToLLVMComplexRealPtrOpLowering, CIRToLLVMComplexImagPtrOpLowering, + CIRToLLVMVAStartOpLowering, CIRToLLVMVAEndOpLowering, + CIRToLLVMVACopyOpLowering, CIRToLLVMVAArgOpLowering, + CIRToLLVMBrOpLowering, CIRToLLVMGetMemberOpLowering, + CIRToLLVMGetRuntimeMemberOpLowering, CIRToLLVMSwitchFlatOpLowering, + CIRToLLVMPtrDiffOpLowering, CIRToLLVMCopyOpLowering, + CIRToLLVMMemCpyOpLowering, CIRToLLVMMemChrOpLowering, + CIRToLLVMAbsOpLowering, CIRToLLVMExpectOpLowering, + CIRToLLVMVTableAddrPointOpLowering, CIRToLLVMVecCreateOpLowering, + CIRToLLVMVecCmpOpLowering, CIRToLLVMVecSplatOpLowering, + CIRToLLVMVecTernaryOpLowering, CIRToLLVMVecShuffleDynamicOpLowering, + CIRToLLVMVecShuffleOpLowering, CIRToLLVMStackSaveOpLowering, + CIRToLLVMUnreachableOpLowering, CIRToLLVMTrapOpLowering, + CIRToLLVMInlineAsmOpLowering, CIRToLLVMSetBitfieldOpLowering, + CIRToLLVMGetBitfieldOpLowering, CIRToLLVMPrefetchOpLowering, + CIRToLLVMObjSizeOpLowering, CIRToLLVMIsConstantOpLowering, + CIRToLLVMCmpThreeWayOpLowering, CIRToLLVMMemCpyOpLowering, + CIRToLLVMIsConstantOpLowering, CIRToLLVMCmpThreeWayOpLowering, + CIRToLLVMReturnAddrOpLowering, CIRToLLVMClearCacheOpLowering, + CIRToLLVMEhTypeIdOpLowering, CIRToLLVMCatchParamOpLowering, + CIRToLLVMResumeOpLowering, CIRToLLVMAllocExceptionOpLowering, + CIRToLLVMFreeExceptionOpLowering, CIRToLLVMThrowOpLowering, + CIRToLLVMLLVMIntrinsicCallOpLowering, CIRToLLVMAssumeOpLowering, + CIRToLLVMAssumeAlignedOpLowering, CIRToLLVMAssumeSepStorageOpLowering, + CIRToLLVMBaseClassAddrOpLowering, CIRToLLVMDerivedClassAddrOpLowering, + CIRToLLVMVTTAddrPointOpLowering, CIRToLLVMIsFPClassOpLowering, + CIRToLLVMAbsOpLowering, CIRToLLVMMemMoveOpLowering, + CIRToLLVMMemSetOpLowering, CIRToLLVMMemCpyInlineOpLowering, + CIRToLLVMSignBitOpLowering, CIRToLLVMPtrMaskOpLowering #define GET_BUILTIN_LOWERING_LIST #include "clang/CIR/Dialect/IR/CIRBuiltinsLowering.inc" #undef GET_BUILTIN_LOWERING_LIST >(converter, patterns.getContext()); } -namespace { - std::unique_ptr prepareLowerModule(mlir::ModuleOp module) { mlir::PatternRewriter rewriter{module->getContext()}; // If the triple is not present, e.g. CIR modules parsed from text, we @@ -4576,9 +4030,8 @@ void prepareTypeConverter(mlir::LLVMTypeConverter &converter, return mlir::LLVM::LLVMVoidType::get(type.getContext()); }); } -} // namespace -static void buildCtorDtorList( +void buildCtorDtorList( mlir::ModuleOp module, StringRef globalXtorName, StringRef llvmXtorName, llvm::function_ref(mlir::Attribute)> createXtor) { llvm::SmallVector, 2> globalXtors; diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h new file mode 100644 index 000000000000..d1488ec8f6f5 --- /dev/null +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h @@ -0,0 +1,1024 @@ +//====- LowerToLLVM.h - Lowering from CIR to LLVMIR -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "LowerModule.h" + +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace cir { +namespace direct { +mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, mlir::Attribute attr, + mlir::ConversionPatternRewriter &rewriter, + const mlir::TypeConverter *converter); + +mlir::LLVM::Linkage convertLinkage(cir::GlobalLinkageKind linkage); + +mlir::LLVM::CConv convertCallingConv(cir::CallingConv callinvConv); + +void buildCtorDtorList( + mlir::ModuleOp module, mlir::StringRef globalXtorName, + mlir::StringRef llvmXtorName, + llvm::function_ref(mlir::Attribute)> + createXtor); + +void populateCIRToLLVMConversionPatterns( + mlir::RewritePatternSet &patterns, mlir::TypeConverter &converter, + mlir::DataLayout &dataLayout, + llvm::StringMap &stringGlobalsMap, + llvm::StringMap &argStringGlobalsMap, + llvm::MapVector &argsVarMap); + +std::unique_ptr prepareLowerModule(mlir::ModuleOp module); + +void prepareTypeConverter(mlir::LLVMTypeConverter &converter, + mlir::DataLayout &dataLayout, + cir::LowerModule *lowerModule); + +mlir::LLVM::AtomicOrdering +getLLVMMemOrder(std::optional &memorder); + +mlir::LLVM::AtomicOrdering getLLVMAtomicOrder(cir::MemOrder memo); + +mlir::LLVM::CallIntrinsicOp +createCallLLVMIntrinsicOp(mlir::ConversionPatternRewriter &rewriter, + mlir::Location loc, const llvm::Twine &intrinsicName, + mlir::Type resultTy, mlir::ValueRange operands); + +mlir::LLVM::CallIntrinsicOp replaceOpWithCallLLVMIntrinsicOp( + mlir::ConversionPatternRewriter &rewriter, mlir::Operation *op, + const llvm::Twine &intrinsicName, mlir::Type resultTy, + mlir::ValueRange operands); + +mlir::Value createLLVMBitOp(mlir::Location loc, + const llvm::Twine &llvmIntrinBaseName, + mlir::Type resultTy, mlir::Value operand, + std::optional poisonZeroInputFlag, + mlir::ConversionPatternRewriter &rewriter); + +class CIRToLLVMCopyOpLowering : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::CopyOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMMemCpyOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::MemCpyOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMMemChrOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::MemChrOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMMemMoveOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::MemMoveOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMMemCpyInlineOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::MemCpyInlineOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override; +}; + +class CIRToLLVMMemSetOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::MemSetOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMPtrStrideOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::PtrStrideOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMBaseClassAddrOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::BaseClassAddrOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMDerivedClassAddrOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::DerivedClassAddrOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMVTTAddrPointOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::VTTAddrPointOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMBrCondOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::BrCondOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMCastOpLowering : public mlir::OpConversionPattern { + mlir::Type convertTy(mlir::Type ty) const; + +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::CastOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMReturnOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::ReturnOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMCallOpLowering : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::CallOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMTryCallOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::TryCallOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMEhInflightOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::EhInflightOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMAllocaOpLowering + : public mlir::OpConversionPattern { + mlir::DataLayout const &dataLayout; + // Track globals created for annotation related strings + llvm::StringMap &stringGlobalsMap; + // Track globals created for annotation arg related strings. + // They are different from annotation strings, as strings used in args + // are not in llvmMetadataSectionName, and also has aligment 1. + llvm::StringMap &argStringGlobalsMap; + // Track globals created for annotation args. + llvm::MapVector &argsVarMap; + +public: + CIRToLLVMAllocaOpLowering( + mlir::TypeConverter const &typeConverter, + mlir::DataLayout const &dataLayout, + llvm::StringMap &stringGlobalsMap, + llvm::StringMap &argStringGlobalsMap, + llvm::MapVector &argsVarMap, + mlir::MLIRContext *context) + : OpConversionPattern(typeConverter, context), + dataLayout(dataLayout), stringGlobalsMap(stringGlobalsMap), + argStringGlobalsMap(argStringGlobalsMap), argsVarMap(argsVarMap) {} + + using mlir::OpConversionPattern::OpConversionPattern; + + void buildAllocaAnnotations(mlir::LLVM::AllocaOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter, + mlir::ArrayAttr annotationValuesArray) const; + + mlir::LogicalResult + matchAndRewrite(cir::AllocaOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMLoadOpLowering : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::LoadOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMStoreOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::StoreOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMConstantOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::ConstantOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMVecCreateOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::VecCreateOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMVecCmpOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::VecCmpOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMVecSplatOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::VecSplatOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMVecTernaryOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::VecTernaryOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMVecShuffleOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::VecShuffleOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMVecShuffleDynamicOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern< + cir::VecShuffleDynamicOp>::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::VecShuffleDynamicOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMVAStartOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::VAStartOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMVAEndOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::VAEndOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMVACopyOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::VACopyOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMVAArgOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::VAArgOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMFuncOpLowering : public mlir::OpConversionPattern { + static mlir::StringRef getLinkageAttrNameString(); + + void lowerFuncAttributes( + cir::FuncOp func, bool filterArgAndResAttrs, + mlir::SmallVectorImpl &result) const; + + void + lowerFuncOpenCLKernelMetadata(mlir::NamedAttribute &extraAttrsEntry) const; + +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::FuncOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMGetGlobalOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::GetGlobalOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMComplexCreateOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::ComplexCreateOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMComplexRealOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::ComplexRealOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMComplexImagOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::ComplexImagOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMComplexRealPtrOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::ComplexRealPtrOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMComplexImagPtrOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::ComplexImagPtrOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMSwitchFlatOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::SwitchFlatOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMGlobalOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::GlobalOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; + +private: + void setupRegionInitializedLLVMGlobalOp( + cir::GlobalOp op, mlir::ConversionPatternRewriter &rewriter) const; + + mutable mlir::LLVM::ComdatOp comdatOp = nullptr; + static void addComdat(mlir::LLVM::GlobalOp &op, + mlir::LLVM::ComdatOp &comdatOp, + mlir::OpBuilder &builder, mlir::ModuleOp &module); +}; + +class CIRToLLVMUnaryOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::UnaryOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMBinOpLowering : public mlir::OpConversionPattern { + mlir::LLVM::IntegerOverflowFlags getIntOverflowFlag(cir::BinOp op) const; + +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::BinOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMBinOpOverflowOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::BinOpOverflowOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; + +private: + static std::string getLLVMIntrinName(cir::BinOpOverflowKind opKind, + bool isSigned, unsigned width); + + struct EncompassedTypeInfo { + bool sign; + unsigned width; + }; + + static EncompassedTypeInfo computeEncompassedTypeWidth(cir::IntType operandTy, + cir::IntType resultTy); +}; + +class CIRToLLVMShiftOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::ShiftOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMCmpOpLowering : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::CmpOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMLLVMIntrinsicCallOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern< + cir::LLVMIntrinsicCallOp>::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::LLVMIntrinsicCallOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMAssumeOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::AssumeOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMAssumeAlignedOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::AssumeAlignedOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMAssumeSepStorageOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::AssumeSepStorageOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMBitClrsbOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::BitClrsbOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMObjSizeOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::ObjSizeOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMBitClzOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::BitClzOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMBitCtzOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::BitCtzOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMBitFfsOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::BitFfsOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMBitParityOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::BitParityOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMBitPopcountOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::BitPopcountOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMAtomicCmpXchgLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::AtomicCmpXchg op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMAtomicXchgLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::AtomicXchg op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMAtomicFetchLowering + : public mlir::OpConversionPattern { + mlir::Value buildPostOp(cir::AtomicFetch op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter, + mlir::Value rmwVal, bool isInt) const; + + mlir::Value buildMinMaxPostOp(cir::AtomicFetch op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter, + mlir::Value rmwVal, bool isSigned) const; + + llvm::StringLiteral getLLVMBinop(cir::AtomicFetchKind k, bool isInt) const; + + mlir::LLVM::AtomicBinOp getLLVMAtomicBinOp(cir::AtomicFetchKind k, bool isInt, + bool isSignedInt) const; + +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::AtomicFetch op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMByteswapOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::ByteswapOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMRotateOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::RotateOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMSelectOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::SelectOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMBrOpLowering : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::BrOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMGetMemberOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::GetMemberOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMGetRuntimeMemberOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::GetRuntimeMemberOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMPtrDiffOpLowering + : public mlir::OpConversionPattern { + uint64_t getTypeSize(mlir::Type type, mlir::Operation &op) const; + +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::PtrDiffOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMExpectOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::ExpectOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMVTableAddrPointOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::VTableAddrPointOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMStackSaveOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::StackSaveOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMUnreachableOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::UnreachableOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMTrapOpLowering : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::TrapOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMInlineAsmOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::InlineAsmOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMPrefetchOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::PrefetchOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMSetBitfieldOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::SetBitfieldOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMGetBitfieldOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::GetBitfieldOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMIsConstantOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::IsConstantOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMCmpThreeWayOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::CmpThreeWayOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; + +private: + static std::string getLLVMIntrinsicName(bool signedCmp, unsigned operandWidth, + unsigned resultWidth); +}; + +class CIRToLLVMReturnAddrOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::ReturnAddrOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMClearCacheOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::ClearCacheOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMEhTypeIdOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::EhTypeIdOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMCatchParamOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::CatchParamOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMResumeOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::ResumeOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMAllocExceptionOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::AllocExceptionOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMFreeExceptionOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::FreeExceptionOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMThrowOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::ThrowOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMIsFPClassOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::IsFPClassOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMPtrMaskOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::PtrMaskOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMAbsOpLowering : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::AbsOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + +class CIRToLLVMSignBitOpLowering + : public mlir::OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::SignBitOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override; +}; + +#define GET_BUILTIN_LOWERING_CLASSES_DECLARE +#include "clang/CIR/Dialect/IR/CIRBuiltinsLowering.inc" +#undef GET_BUILTIN_LOWERING_CLASSES_DECLARE + +} // namespace direct +} // namespace cir diff --git a/clang/utils/TableGen/CIRLoweringEmitter.cpp b/clang/utils/TableGen/CIRLoweringEmitter.cpp index 84b5ceea998e..9b71e9ab597d 100644 --- a/clang/utils/TableGen/CIRLoweringEmitter.cpp +++ b/clang/utils/TableGen/CIRLoweringEmitter.cpp @@ -12,6 +12,7 @@ using namespace llvm; namespace { +std::string ClassDeclaration; std::string ClassDefinitions; std::string ClassList; @@ -19,7 +20,8 @@ void GenerateLowering(const Record *Operation) { using namespace std::string_literals; std::string Name = Operation->getName().str(); std::string LLVMOp = Operation->getValueAsString("llvmOp").str(); - ClassDefinitions += + + ClassDeclaration += "class CIR" + Name + "Lowering : public mlir::OpConversionPattern { @@ -32,15 +34,24 @@ void GenerateLowering(const Record *Operation) { Name + " op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) " "const " - "override {"; + "override;" + + R"C++( +}; +)C++"; + + ClassDefinitions += + R"C++(mlir::LogicalResult +CIR)C++" + + Name + "Lowering::matchAndRewrite(cir::" + Name + + R"C++( op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const {)C++"; auto ResultCount = Operation->getValueAsDag("results")->getNumArgs(); if (ResultCount > 0) ClassDefinitions += R"C++( - auto resTy = this->getTypeConverter()->convertType(op.getType());)C++"; + auto resTy = this->getTypeConverter()->convertType(op.getType());)C++"; ClassDefinitions += R"C++( - rewriter.replaceOpWithNewOp(op"; if (ResultCount > 0) @@ -51,9 +62,8 @@ void GenerateLowering(const Record *Operation) { ClassDefinitions += ", adaptor.getOperands()[" + std::to_string(i) + ']'; ClassDefinitions += R"C++(); - return mlir::success(); - } -}; + return mlir::success(); +} )C++"; ClassList += ", CIR" + Name + "Lowering\n"; @@ -69,8 +79,9 @@ void clang::EmitCIRBuiltinsLowering(const RecordKeeper &Records, GenerateLowering(Builtin); } - OS << "#ifdef GET_BUILTIN_LOWERING_CLASSES\n" - << ClassDefinitions << "\n#undef GET_BUILTIN_LOWERING_CLASSES\n#endif\n"; - OS << "#ifdef GET_BUILTIN_LOWERING_LIST\n" - << ClassList << "\n#undef GET_BUILTIN_LOWERING_LIST\n#endif\n"; + OS << "#ifdef GET_BUILTIN_LOWERING_CLASSES_DECLARE\n" + << ClassDeclaration << "\n#endif\n"; + OS << "#ifdef GET_BUILTIN_LOWERING_CLASSES_DEF\n" + << ClassDefinitions << "\n#endif\n"; + OS << "#ifdef GET_BUILTIN_LOWERING_LIST\n" << ClassList << "\n#endif\n"; }