diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index 8c020922f3b1..3b5f0c1ded6b 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -22,6 +22,7 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/LLVMIR/Transforms/Passes.h" @@ -34,6 +35,7 @@ #include "mlir/IR/IRMapping.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Export.h" @@ -333,7 +335,8 @@ class CIRIfLowering : public mlir::OpConversionPattern { dyn_cast(elseAfterBody->getTerminator())) { rewriter.replaceOpWithNewOp( elseYieldOp, elseYieldOp.getArgs(), continueBlock); - } else if (!dyn_cast(elseAfterBody->getTerminator())) { + } else if (!dyn_cast( + elseAfterBody->getTerminator())) { llvm_unreachable("what are we terminating with?"); } } @@ -889,170 +892,77 @@ class CIRCmpOpLowering : public mlir::OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; + mlir::LLVM::ICmpPredicate + convertToICmpPredicate(mlir::cir::CmpOpKind kind) const { + using CIR = mlir::cir::CmpOpKind; + using LLVMICmp = mlir::LLVM::ICmpPredicate; + + switch (kind) { + case CIR::eq: + return LLVMICmp::eq; + case CIR::ne: + return LLVMICmp::ne; + case CIR::lt: + return LLVMICmp::ult; + case CIR::le: + return LLVMICmp::ule; + case CIR::gt: + return LLVMICmp::ugt; + case CIR::ge: + return LLVMICmp::uge; + } + llvm_unreachable("Unknown CmpOpKind"); + } + + mlir::LLVM::FCmpPredicate + convertToFCmpPredicate(mlir::cir::CmpOpKind kind) const { + using CIR = mlir::cir::CmpOpKind; + using LLVMFCmp = mlir::LLVM::FCmpPredicate; + + switch (kind) { + case CIR::eq: + return LLVMFCmp::ueq; + case CIR::ne: + return LLVMFCmp::une; + case CIR::lt: + return LLVMFCmp::ult; + case CIR::le: + return LLVMFCmp::ule; + case CIR::gt: + return LLVMFCmp::ugt; + case CIR::ge: + return LLVMFCmp::uge; + } + llvm_unreachable("Unknown CmpOpKind"); + } + mlir::LogicalResult matchAndRewrite(mlir::cir::CmpOp cmpOp, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { auto type = adaptor.getLhs().getType(); - auto i1Type = - mlir::IntegerType::get(getContext(), 1, mlir::IntegerType::Signless); - auto destType = getTypeConverter()->convertType(cmpOp.getType()); - - switch (adaptor.getKind()) { - case mlir::cir::CmpOpKind::gt: { - if (type.isa()) { - mlir::LLVM::ICmpPredicate cmpIType; - if (!type.isSignlessInteger()) - llvm_unreachable("integer type not supported in CIR yet"); - cmpIType = mlir::LLVM::ICmpPredicate::ugt; - auto cmp = rewriter.create( - cmpOp.getLoc(), i1Type, - mlir::LLVM::ICmpPredicateAttr::get(getContext(), cmpIType), - adaptor.getLhs(), adaptor.getRhs()); - rewriter.replaceOpWithNewOp(cmpOp, destType, - cmp.getRes()); - } else if (type.isa()) { - auto cmp = rewriter.create( - cmpOp.getLoc(), i1Type, - mlir::LLVM::FCmpPredicateAttr::get(getContext(), - mlir::LLVM::FCmpPredicate::ugt), - adaptor.getLhs(), adaptor.getRhs(), - // TODO(CIR): These fastmath flags need to not be defaulted. - mlir::LLVM::FastmathFlagsAttr::get(cmpOp.getContext(), {})); - rewriter.replaceOpWithNewOp(cmpOp, destType, - cmp.getRes()); - } else { - llvm_unreachable("Unknown Operand Type"); - } - break; - } - case mlir::cir::CmpOpKind::ge: { - if (type.isa()) { - mlir::LLVM::ICmpPredicate cmpIType; - if (!type.isSignlessInteger()) - llvm_unreachable("integer type not supported in CIR yet"); - cmpIType = mlir::LLVM::ICmpPredicate::uge; - auto cmp = rewriter.create( - cmpOp.getLoc(), i1Type, - mlir::LLVM::ICmpPredicateAttr::get(getContext(), cmpIType), - adaptor.getLhs(), adaptor.getRhs()); - rewriter.replaceOpWithNewOp(cmpOp, destType, - cmp.getRes()); - } else if (type.isa()) { - auto cmp = rewriter.create( - cmpOp.getLoc(), i1Type, - mlir::LLVM::FCmpPredicateAttr::get(getContext(), - mlir::LLVM::FCmpPredicate::uge), - adaptor.getLhs(), adaptor.getRhs(), - mlir::LLVM::FastmathFlagsAttr::get(cmpOp.getContext(), {})); - rewriter.replaceOpWithNewOp(cmpOp, destType, - cmp.getRes()); - } else { - llvm_unreachable("Unknown Operand Type"); - } - break; - } - case mlir::cir::CmpOpKind::lt: { - if (type.isa()) { - mlir::LLVM::ICmpPredicate cmpIType; - if (!type.isSignlessInteger()) - llvm_unreachable("integer type not supported in CIR yet"); - cmpIType = mlir::LLVM::ICmpPredicate::ult; - auto cmp = rewriter.create( - cmpOp.getLoc(), i1Type, - mlir::LLVM::ICmpPredicateAttr::get(getContext(), cmpIType), - adaptor.getLhs(), adaptor.getRhs()); - rewriter.replaceOpWithNewOp(cmpOp, destType, - cmp.getRes()); - } else if (type.isa()) { - auto cmp = rewriter.create( - cmpOp.getLoc(), i1Type, - mlir::LLVM::FCmpPredicateAttr::get(getContext(), - mlir::LLVM::FCmpPredicate::ult), - adaptor.getLhs(), adaptor.getRhs(), - mlir::LLVM::FastmathFlagsAttr::get(cmpOp.getContext(), {})); - rewriter.replaceOpWithNewOp(cmpOp, destType, - cmp.getRes()); - } else { - llvm_unreachable("Unknown Operand Type"); - } - break; - } - case mlir::cir::CmpOpKind::le: { - if (type.isa()) { - mlir::LLVM::ICmpPredicate cmpIType; - if (!type.isSignlessInteger()) - llvm_unreachable("integer type not supported in CIR yet"); - cmpIType = mlir::LLVM::ICmpPredicate::ule; - auto cmp = rewriter.create( - cmpOp.getLoc(), i1Type, - mlir::LLVM::ICmpPredicateAttr::get(getContext(), cmpIType), - adaptor.getLhs(), adaptor.getRhs()); - rewriter.replaceOpWithNewOp(cmpOp, destType, - cmp.getRes()); - } else if (type.isa()) { - auto cmp = rewriter.create( - cmpOp.getLoc(), i1Type, - mlir::LLVM::FCmpPredicateAttr::get(getContext(), - mlir::LLVM::FCmpPredicate::ule), - adaptor.getLhs(), adaptor.getRhs(), - mlir::LLVM::FastmathFlagsAttr::get(cmpOp.getContext(), {})); - rewriter.replaceOpWithNewOp(cmpOp, destType, - cmp.getRes()); - } else { - llvm_unreachable("Unknown Operand Type"); - } - break; - } - case mlir::cir::CmpOpKind::eq: { - if (type.isa()) { - auto cmp = rewriter.create( - cmpOp.getLoc(), i1Type, - mlir::LLVM::ICmpPredicateAttr::get(getContext(), - mlir::LLVM::ICmpPredicate::eq), - adaptor.getLhs(), adaptor.getRhs()); - rewriter.replaceOpWithNewOp(cmpOp, destType, - cmp.getRes()); - } else if (type.isa()) { - auto cmp = rewriter.create( - cmpOp.getLoc(), i1Type, - mlir::LLVM::FCmpPredicateAttr::get(getContext(), - mlir::LLVM::FCmpPredicate::ueq), - adaptor.getLhs(), adaptor.getRhs(), - mlir::LLVM::FastmathFlagsAttr::get(cmpOp.getContext(), {})); - rewriter.replaceOpWithNewOp(cmpOp, destType, - cmp.getRes()); - } else { - llvm_unreachable("Unknown Operand Type"); - } - break; - } - case mlir::cir::CmpOpKind::ne: { - if (type.isa()) { - auto cmp = rewriter.create( - cmpOp.getLoc(), i1Type, - mlir::LLVM::ICmpPredicateAttr::get(getContext(), - mlir::LLVM::ICmpPredicate::ne), - adaptor.getLhs(), adaptor.getRhs()); - - rewriter.replaceOpWithNewOp(cmpOp, destType, - cmp.getRes()); - } else if (type.isa()) { - auto cmp = rewriter.create( - cmpOp.getLoc(), i1Type, - mlir::LLVM::FCmpPredicateAttr::get(getContext(), - mlir::LLVM::FCmpPredicate::une), - adaptor.getLhs(), adaptor.getRhs(), - mlir::LLVM::FastmathFlagsAttr::get(cmpOp.getContext(), {})); - rewriter.replaceOpWithNewOp(cmpOp, destType, - cmp.getRes()); - } else { - llvm_unreachable("Unknown Operand Type"); - } - break; - } + mlir::Value llResult; + + // Lower to LLVM comparison op. + if (auto intTy = type.dyn_cast()) { + auto kind = convertToICmpPredicate(cmpOp.getKind()); + llResult = rewriter.create( + cmpOp.getLoc(), kind, adaptor.getLhs(), adaptor.getRhs()); + } else if (type.isa()) { + auto kind = convertToFCmpPredicate(cmpOp.getKind()); + llResult = rewriter.create( + cmpOp.getLoc(), kind, adaptor.getLhs(), adaptor.getRhs()); + } else { + return cmpOp.emitError() << "unsupported type for CmpOp: " << type; } - return mlir::LogicalResult::success(); + // LLVM comparison ops return i1, but cir::CmpOp returns the same type as + // the LHS value. Since this return value can be used later, we need to + // restore the type with the extension below. + auto llResultTy = getTypeConverter()->convertType(cmpOp.getType()); + rewriter.replaceOpWithNewOp(cmpOp, llResultTy, + llResult); + + return mlir::success(); } };