Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 69 additions & 159 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/Transforms/Passes.h"
Expand All @@ -34,6 +35,7 @@
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
Expand Down Expand Up @@ -333,7 +335,8 @@ class CIRIfLowering : public mlir::OpConversionPattern<mlir::cir::IfOp> {
dyn_cast<mlir::cir::YieldOp>(elseAfterBody->getTerminator())) {
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(
elseYieldOp, elseYieldOp.getArgs(), continueBlock);
} else if (!dyn_cast<mlir::cir::ReturnOp>(elseAfterBody->getTerminator())) {
} else if (!dyn_cast<mlir::cir::ReturnOp>(
elseAfterBody->getTerminator())) {
llvm_unreachable("what are we terminating with?");
}
}
Expand Down Expand Up @@ -889,170 +892,77 @@ class CIRCmpOpLowering : public mlir::OpConversionPattern<mlir::cir::CmpOp> {
public:
using OpConversionPattern<mlir::cir::CmpOp>::OpConversionPattern;

mlir::LLVM::ICmpPredicate
convertToICmpPredicate(mlir::cir::CmpOpKind kind) const {
using CIR = mlir::cir::CmpOpKind;
using LLVMICmp = mlir::LLVM::ICmpPredicate;

switch (kind) {
case CIR::eq:
return LLVMICmp::eq;
case CIR::ne:
return LLVMICmp::ne;
case CIR::lt:
return LLVMICmp::ult;
case CIR::le:
return LLVMICmp::ule;
case CIR::gt:
return LLVMICmp::ugt;
case CIR::ge:
return LLVMICmp::uge;
}
llvm_unreachable("Unknown CmpOpKind");
}

mlir::LLVM::FCmpPredicate
convertToFCmpPredicate(mlir::cir::CmpOpKind kind) const {
using CIR = mlir::cir::CmpOpKind;
using LLVMFCmp = mlir::LLVM::FCmpPredicate;

switch (kind) {
case CIR::eq:
return LLVMFCmp::ueq;
case CIR::ne:
return LLVMFCmp::une;
case CIR::lt:
return LLVMFCmp::ult;
case CIR::le:
return LLVMFCmp::ule;
case CIR::gt:
return LLVMFCmp::ugt;
case CIR::ge:
return LLVMFCmp::uge;
}
llvm_unreachable("Unknown CmpOpKind");
}

mlir::LogicalResult
matchAndRewrite(mlir::cir::CmpOp cmpOp, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto type = adaptor.getLhs().getType();
auto i1Type =
mlir::IntegerType::get(getContext(), 1, mlir::IntegerType::Signless);
auto destType = getTypeConverter()->convertType(cmpOp.getType());

switch (adaptor.getKind()) {
case mlir::cir::CmpOpKind::gt: {
if (type.isa<mlir::IntegerType>()) {
mlir::LLVM::ICmpPredicate cmpIType;
if (!type.isSignlessInteger())
llvm_unreachable("integer type not supported in CIR yet");
cmpIType = mlir::LLVM::ICmpPredicate::ugt;
auto cmp = rewriter.create<mlir::LLVM::ICmpOp>(
cmpOp.getLoc(), i1Type,
mlir::LLVM::ICmpPredicateAttr::get(getContext(), cmpIType),
adaptor.getLhs(), adaptor.getRhs());
rewriter.replaceOpWithNewOp<mlir::LLVM::ZExtOp>(cmpOp, destType,
cmp.getRes());
} else if (type.isa<mlir::FloatType>()) {
auto cmp = rewriter.create<mlir::LLVM::FCmpOp>(
cmpOp.getLoc(), i1Type,
mlir::LLVM::FCmpPredicateAttr::get(getContext(),
mlir::LLVM::FCmpPredicate::ugt),
adaptor.getLhs(), adaptor.getRhs(),
// TODO(CIR): These fastmath flags need to not be defaulted.
mlir::LLVM::FastmathFlagsAttr::get(cmpOp.getContext(), {}));
rewriter.replaceOpWithNewOp<mlir::LLVM::ZExtOp>(cmpOp, destType,
cmp.getRes());
} else {
llvm_unreachable("Unknown Operand Type");
}
break;
}
case mlir::cir::CmpOpKind::ge: {
if (type.isa<mlir::IntegerType>()) {
mlir::LLVM::ICmpPredicate cmpIType;
if (!type.isSignlessInteger())
llvm_unreachable("integer type not supported in CIR yet");
cmpIType = mlir::LLVM::ICmpPredicate::uge;
auto cmp = rewriter.create<mlir::LLVM::ICmpOp>(
cmpOp.getLoc(), i1Type,
mlir::LLVM::ICmpPredicateAttr::get(getContext(), cmpIType),
adaptor.getLhs(), adaptor.getRhs());
rewriter.replaceOpWithNewOp<mlir::LLVM::ZExtOp>(cmpOp, destType,
cmp.getRes());
} else if (type.isa<mlir::FloatType>()) {
auto cmp = rewriter.create<mlir::LLVM::FCmpOp>(
cmpOp.getLoc(), i1Type,
mlir::LLVM::FCmpPredicateAttr::get(getContext(),
mlir::LLVM::FCmpPredicate::uge),
adaptor.getLhs(), adaptor.getRhs(),
mlir::LLVM::FastmathFlagsAttr::get(cmpOp.getContext(), {}));
rewriter.replaceOpWithNewOp<mlir::LLVM::ZExtOp>(cmpOp, destType,
cmp.getRes());
} else {
llvm_unreachable("Unknown Operand Type");
}
break;
}
case mlir::cir::CmpOpKind::lt: {
if (type.isa<mlir::IntegerType>()) {
mlir::LLVM::ICmpPredicate cmpIType;
if (!type.isSignlessInteger())
llvm_unreachable("integer type not supported in CIR yet");
cmpIType = mlir::LLVM::ICmpPredicate::ult;
auto cmp = rewriter.create<mlir::LLVM::ICmpOp>(
cmpOp.getLoc(), i1Type,
mlir::LLVM::ICmpPredicateAttr::get(getContext(), cmpIType),
adaptor.getLhs(), adaptor.getRhs());
rewriter.replaceOpWithNewOp<mlir::LLVM::ZExtOp>(cmpOp, destType,
cmp.getRes());
} else if (type.isa<mlir::FloatType>()) {
auto cmp = rewriter.create<mlir::LLVM::FCmpOp>(
cmpOp.getLoc(), i1Type,
mlir::LLVM::FCmpPredicateAttr::get(getContext(),
mlir::LLVM::FCmpPredicate::ult),
adaptor.getLhs(), adaptor.getRhs(),
mlir::LLVM::FastmathFlagsAttr::get(cmpOp.getContext(), {}));
rewriter.replaceOpWithNewOp<mlir::LLVM::ZExtOp>(cmpOp, destType,
cmp.getRes());
} else {
llvm_unreachable("Unknown Operand Type");
}
break;
}
case mlir::cir::CmpOpKind::le: {
if (type.isa<mlir::IntegerType>()) {
mlir::LLVM::ICmpPredicate cmpIType;
if (!type.isSignlessInteger())
llvm_unreachable("integer type not supported in CIR yet");
cmpIType = mlir::LLVM::ICmpPredicate::ule;
auto cmp = rewriter.create<mlir::LLVM::ICmpOp>(
cmpOp.getLoc(), i1Type,
mlir::LLVM::ICmpPredicateAttr::get(getContext(), cmpIType),
adaptor.getLhs(), adaptor.getRhs());
rewriter.replaceOpWithNewOp<mlir::LLVM::ZExtOp>(cmpOp, destType,
cmp.getRes());
} else if (type.isa<mlir::FloatType>()) {
auto cmp = rewriter.create<mlir::LLVM::FCmpOp>(
cmpOp.getLoc(), i1Type,
mlir::LLVM::FCmpPredicateAttr::get(getContext(),
mlir::LLVM::FCmpPredicate::ule),
adaptor.getLhs(), adaptor.getRhs(),
mlir::LLVM::FastmathFlagsAttr::get(cmpOp.getContext(), {}));
rewriter.replaceOpWithNewOp<mlir::LLVM::ZExtOp>(cmpOp, destType,
cmp.getRes());
} else {
llvm_unreachable("Unknown Operand Type");
}
break;
}
case mlir::cir::CmpOpKind::eq: {
if (type.isa<mlir::IntegerType>()) {
auto cmp = rewriter.create<mlir::LLVM::ICmpOp>(
cmpOp.getLoc(), i1Type,
mlir::LLVM::ICmpPredicateAttr::get(getContext(),
mlir::LLVM::ICmpPredicate::eq),
adaptor.getLhs(), adaptor.getRhs());
rewriter.replaceOpWithNewOp<mlir::LLVM::ZExtOp>(cmpOp, destType,
cmp.getRes());
} else if (type.isa<mlir::FloatType>()) {
auto cmp = rewriter.create<mlir::LLVM::FCmpOp>(
cmpOp.getLoc(), i1Type,
mlir::LLVM::FCmpPredicateAttr::get(getContext(),
mlir::LLVM::FCmpPredicate::ueq),
adaptor.getLhs(), adaptor.getRhs(),
mlir::LLVM::FastmathFlagsAttr::get(cmpOp.getContext(), {}));
rewriter.replaceOpWithNewOp<mlir::LLVM::ZExtOp>(cmpOp, destType,
cmp.getRes());
} else {
llvm_unreachable("Unknown Operand Type");
}
break;
}
case mlir::cir::CmpOpKind::ne: {
if (type.isa<mlir::IntegerType>()) {
auto cmp = rewriter.create<mlir::LLVM::ICmpOp>(
cmpOp.getLoc(), i1Type,
mlir::LLVM::ICmpPredicateAttr::get(getContext(),
mlir::LLVM::ICmpPredicate::ne),
adaptor.getLhs(), adaptor.getRhs());

rewriter.replaceOpWithNewOp<mlir::LLVM::ZExtOp>(cmpOp, destType,
cmp.getRes());
} else if (type.isa<mlir::FloatType>()) {
auto cmp = rewriter.create<mlir::LLVM::FCmpOp>(
cmpOp.getLoc(), i1Type,
mlir::LLVM::FCmpPredicateAttr::get(getContext(),
mlir::LLVM::FCmpPredicate::une),
adaptor.getLhs(), adaptor.getRhs(),
mlir::LLVM::FastmathFlagsAttr::get(cmpOp.getContext(), {}));
rewriter.replaceOpWithNewOp<mlir::LLVM::ZExtOp>(cmpOp, destType,
cmp.getRes());
} else {
llvm_unreachable("Unknown Operand Type");
}
break;
}
mlir::Value llResult;

// Lower to LLVM comparison op.
if (auto intTy = type.dyn_cast<mlir::IntegerType>()) {
auto kind = convertToICmpPredicate(cmpOp.getKind());
llResult = rewriter.create<mlir::LLVM::ICmpOp>(
cmpOp.getLoc(), kind, adaptor.getLhs(), adaptor.getRhs());
} else if (type.isa<mlir::FloatType>()) {
auto kind = convertToFCmpPredicate(cmpOp.getKind());
llResult = rewriter.create<mlir::LLVM::FCmpOp>(
cmpOp.getLoc(), kind, adaptor.getLhs(), adaptor.getRhs());
} else {
return cmpOp.emitError() << "unsupported type for CmpOp: " << type;
}

return mlir::LogicalResult::success();
// LLVM comparison ops return i1, but cir::CmpOp returns the same type as
// the LHS value. Since this return value can be used later, we need to
// restore the type with the extension below.
auto llResultTy = getTypeConverter()->convertType(cmpOp.getType());
rewriter.replaceOpWithNewOp<mlir::LLVM::ZExtOp>(cmpOp, llResultTy,
llResult);

return mlir::success();
}
};

Expand Down