Skip to content

Commit

Permalink
[CIR][CIRGen] Add CIRGen for binary fp2fp builtin operations (llvm#616)
Browse files Browse the repository at this point in the history
This PR adds the following operations for the builtin binary fp2fp
functions:

  - `cir.copysign` for `__builtin_copysign`;
  - `cir.fmax` for `__builtin_fmax`;
  - `cir.fmin` for `__builtin_fmin`;
  - `cir.fmod` for `__builtin_fmod`;
  - `cir.pow` for `__builtin_pow`.

This PR also includes CIRGen support for these new operations.
  • Loading branch information
Lancern authored and lanza committed Oct 12, 2024
1 parent 966a733 commit 01dd7d8
Show file tree
Hide file tree
Showing 6 changed files with 501 additions and 12 deletions.
20 changes: 20 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3324,6 +3324,26 @@ def SinOp : UnaryFPToFPBuiltinOp<"sin">;
def SqrtOp : UnaryFPToFPBuiltinOp<"sqrt">;
def TruncOp : UnaryFPToFPBuiltinOp<"trunc">;

// Base class for CIR operations that take two floating-point operands and
// produce one floating-point result.
//
// Every op derived from this class carries the `Pure` trait and the
// `SameOperandsAndResultType` trait, so `$lhs`, `$rhs` and `$result` all have
// one and the same type. The `mnemonic` template argument is forwarded to
// `CIR_Op` unchanged.
class BinaryFPToFPBuiltinOp<string mnemonic>
: CIR_Op<mnemonic, [Pure, SameOperandsAndResultType]> {
let summary = [{
libc builtin equivalent ignoring floating-point exceptions and errno.
}];

// Two floating-point inputs and a single floating-point output.
let arguments = (ins CIR_AnyFloat:$lhs, CIR_AnyFloat:$rhs);
let results = (outs CIR_AnyFloat:$result);

// Textual form: `<lhs>, <rhs> : <type of lhs>` followed by the attribute
// dictionary. Only the type of `$lhs` is spelled out in the format.
let assemblyFormat = [{
$lhs `,` $rhs `:` qualified(type($lhs)) attr-dict
}];
}

// Concrete binary fp2fp builtin operations. Each one instantiates
// BinaryFPToFPBuiltinOp with its own mnemonic:
//   copysign -> __builtin_copysign
//   fmax     -> __builtin_fmax
//   fmin     -> __builtin_fmin
//   fmod     -> __builtin_fmod
//   pow      -> __builtin_pow
def CopysignOp : BinaryFPToFPBuiltinOp<"copysign">;
def FMaxOp : BinaryFPToFPBuiltinOp<"fmax">;
def FMinOp : BinaryFPToFPBuiltinOp<"fmin">;
def FModOp : BinaryFPToFPBuiltinOp<"fmod">;
def PowOp : BinaryFPToFPBuiltinOp<"pow">;

//===----------------------------------------------------------------------===//
// Branch Probability Operations
//===----------------------------------------------------------------------===//
Expand Down
56 changes: 49 additions & 7 deletions clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,36 @@ static RValue buildUnaryFPBuiltin(CIRGenFunction &CGF, const CallExpr &E) {
return RValue::get(Call->getResult(0));
}

/// Emit the CIR operation `Op` for a builtin call with two scalar arguments.
///
/// The two call arguments are emitted as scalar expressions, first argument
/// first. `Op` is then created at the source location of the call expression,
/// using the converted type of the call expression as its result type.
///
/// \param CGF the function-level code generator used to emit the operands and
///            to build the operation.
/// \param E   the builtin call expression; arguments 0 and 1 are read.
/// \returns result 0 of the newly created operation, wrapped in an RValue.
template <typename Op>
static RValue buildBinaryFPBuiltin(CIRGenFunction &CGF, const CallExpr &E) {
  // Operand emission order matters for the generated IR: lhs, then rhs.
  auto Lhs = CGF.buildScalarExpr(E.getArg(0));
  auto Rhs = CGF.buildScalarExpr(E.getArg(1));

  auto CallLoc = CGF.getLoc(E.getExprLoc());
  auto ResultTy = CGF.ConvertType(E.getType());

  return RValue::get(
      CGF.getBuilder().create<Op>(CallLoc, ResultTy, Lhs, Rhs)->getResult(0));
}

/// Emit the CIR operation `Op` for a two-argument floating-point builtin call
/// that could require constrained floating-point semantics.
///
/// Both call arguments are emitted as scalar expressions (first argument
/// first) before the builder is queried. If the builder reports that FP
/// operations are constrained, the function aborts through llvm_unreachable
/// because that path is not yet implemented. Otherwise `Op` is created at the
/// call's source location with the converted type of the call expression.
///
/// \param CGF the function-level code generator.
/// \param E   the builtin call expression; arguments 0 and 1 are read.
/// \returns result 0 of the newly created operation.
template <typename Op>
static mlir::Value buildBinaryMaybeConstrainedFPBuiltin(CIRGenFunction &CGF,
                                                        const CallExpr &E) {
  auto Lhs = CGF.buildScalarExpr(E.getArg(0));
  auto Rhs = CGF.buildScalarExpr(E.getArg(1));

  auto CallLoc = CGF.getLoc(E.getExprLoc());
  auto ResultTy = CGF.ConvertType(E.getType());

  // Constrained FP is unsupported for now; bail out before creating anything.
  if (CGF.getBuilder().getIsFPConstrained()) {
    CIRGenFunction::CIRGenFPOptionsRAII FPOptsRAII(CGF, &E);
    llvm_unreachable("constrained FP operations are NYI");
  }

  return CGF.getBuilder()
      .create<Op>(CallLoc, ResultTy, Lhs, Rhs)
      ->getResult(0);
}

template <typename Op>
static RValue
buildBuiltinBitOp(CIRGenFunction &CGF, const CallExpr *E,
Expand Down Expand Up @@ -290,8 +320,10 @@ RValue CIRGenFunction::buildBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
case Builtin::BIcopysignl:
case Builtin::BI__builtin_copysign:
case Builtin::BI__builtin_copysignf:
case Builtin::BI__builtin_copysignf16:
case Builtin::BI__builtin_copysignl:
return buildBinaryFPBuiltin<mlir::cir::CopysignOp>(*this, *E);

case Builtin::BI__builtin_copysignf16:
case Builtin::BI__builtin_copysignf128:
llvm_unreachable("NYI");

Expand Down Expand Up @@ -360,8 +392,11 @@ RValue CIRGenFunction::buildBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
case Builtin::BIfmaxl:
case Builtin::BI__builtin_fmax:
case Builtin::BI__builtin_fmaxf:
case Builtin::BI__builtin_fmaxf16:
case Builtin::BI__builtin_fmaxl:
return RValue::get(
buildBinaryMaybeConstrainedFPBuiltin<mlir::cir::FMaxOp>(*this, *E));

case Builtin::BI__builtin_fmaxf16:
case Builtin::BI__builtin_fmaxf128:
llvm_unreachable("NYI");

Expand All @@ -370,8 +405,11 @@ RValue CIRGenFunction::buildBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
case Builtin::BIfminl:
case Builtin::BI__builtin_fmin:
case Builtin::BI__builtin_fminf:
case Builtin::BI__builtin_fminf16:
case Builtin::BI__builtin_fminl:
return RValue::get(
buildBinaryMaybeConstrainedFPBuiltin<mlir::cir::FMinOp>(*this, *E));

case Builtin::BI__builtin_fminf16:
case Builtin::BI__builtin_fminf128:
llvm_unreachable("NYI");

Expand All @@ -382,11 +420,12 @@ RValue CIRGenFunction::buildBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
case Builtin::BIfmodl:
case Builtin::BI__builtin_fmod:
case Builtin::BI__builtin_fmodf:
case Builtin::BI__builtin_fmodf16:
case Builtin::BI__builtin_fmodl:
case Builtin::BI__builtin_fmodf128: {
return buildBinaryFPBuiltin<mlir::cir::FModOp>(*this, *E);

case Builtin::BI__builtin_fmodf16:
case Builtin::BI__builtin_fmodf128:
llvm_unreachable("NYI");
}

case Builtin::BIlog:
case Builtin::BIlogf:
Expand Down Expand Up @@ -432,8 +471,11 @@ RValue CIRGenFunction::buildBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
case Builtin::BIpowl:
case Builtin::BI__builtin_pow:
case Builtin::BI__builtin_powf:
case Builtin::BI__builtin_powf16:
case Builtin::BI__builtin_powl:
return RValue::get(
buildBinaryMaybeConstrainedFPBuiltin<mlir::cir::PowOp>(*this, *E));

case Builtin::BI__builtin_powf16:
case Builtin::BI__builtin_powf128:
llvm_unreachable("NYI");

Expand Down
53 changes: 51 additions & 2 deletions clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
void lowerIterEndOp(IterEndOp op);
void lowerArrayDtor(ArrayDtor op);
void lowerArrayCtor(ArrayCtor op);
void lowerFModOp(FModOp op);
void lowerPowOp(PowOp op);

/// Build the function that initializes the specified global
FuncOp buildCXXGlobalVarDeclInitFunc(GlobalOp op);
Expand Down Expand Up @@ -625,6 +627,49 @@ void LoweringPreparePass::lowerIterEndOp(IterEndOp op) {
op.erase();
}

/// Replace a binary fp2fp builtin operation with a call to a runtime function.
///
/// The callee name is chosen from the type of the operation's first result:
/// `floatRtFuncName` for `mlir::cir::SingleType`, `doubleRtFuncName` for
/// `mlir::cir::DoubleType` and `longDoubleRtFuncName` for
/// `mlir::cir::LongDoubleType`. Any other result type is a hard error
/// (llvm_unreachable).
///
/// The callee is obtained through `pass.buildRuntimeFunction` with signature
/// `(ty, ty) -> ty`, a `cir.call` passing the two original operands is
/// inserted right after `op`, all uses of `op` are redirected to the call, and
/// `op` is erased.
///
/// \param pass                 the running LoweringPrepare pass; supplies the
///                             module and the runtime-function helper.
/// \param op                   the operation to replace; its operands 0 and 1
///                             and its result 0 are read.
/// \param floatRtFuncName      callee name for single-precision operands.
/// \param doubleRtFuncName     callee name for double-precision operands.
/// \param longDoubleRtFuncName callee name for long double operands.
static void lowerBinaryFPToFPBuiltinOp(LoweringPreparePass &pass,
                                       mlir::Operation *op,
                                       llvm::StringRef floatRtFuncName,
                                       llvm::StringRef doubleRtFuncName,
                                       llvm::StringRef longDoubleRtFuncName) {
  mlir::Type ty = op->getResult(0).getType();

  // Use the free-function casts; the member forms (`ty.isa<T>()`) are
  // deprecated in MLIR.
  llvm::StringRef rtFuncName;
  if (mlir::isa<mlir::cir::SingleType>(ty))
    rtFuncName = floatRtFuncName;
  else if (mlir::isa<mlir::cir::DoubleType>(ty))
    rtFuncName = doubleRtFuncName;
  else if (mlir::isa<mlir::cir::LongDoubleType>(ty))
    rtFuncName = longDoubleRtFuncName;
  else
    llvm_unreachable("unknown binary fp2fp builtin operand type");

  // The runtime function is requested with the builder positioned at the very
  // start of the module body.
  CIRBaseBuilderTy builder(*pass.theModule.getContext());
  builder.setInsertionPointToStart(pass.theModule.getBody());

  auto rtFuncTy = mlir::cir::FuncType::get({ty, ty}, ty);
  FuncOp rtFunc =
      pass.buildRuntimeFunction(builder, rtFuncName, op->getLoc(), rtFuncTy);

  auto lhs = op->getOperand(0);
  auto rhs = op->getOperand(1);

  // Emit the call immediately after the op being replaced so that it sees the
  // same operands, then swap it in.
  builder.setInsertionPointAfter(op);
  auto call = builder.create<mlir::cir::CallOp>(op->getLoc(), rtFunc,
                                                mlir::ValueRange{lhs, rhs});

  op->replaceAllUsesWith(call);
  op->erase();
}

/// Lower a `FModOp` through lowerBinaryFPToFPBuiltinOp, supplying "fmodf",
/// "fmod" and "fmodl" as the callee names for float, double and long double
/// respectively.
void LoweringPreparePass::lowerFModOp(FModOp op) {
lowerBinaryFPToFPBuiltinOp(*this, op, "fmodf", "fmod", "fmodl");
}

/// Lower a `PowOp` through lowerBinaryFPToFPBuiltinOp, supplying "powf",
/// "pow" and "powl" as the callee names for float, double and long double
/// respectively.
void LoweringPreparePass::lowerPowOp(PowOp op) {
lowerBinaryFPToFPBuiltinOp(*this, op, "powf", "pow", "powl");
}

void LoweringPreparePass::runOnOp(Operation *op) {
if (auto threeWayCmp = dyn_cast<CmpThreeWayOp>(op)) {
lowerThreeWayCmpOp(threeWayCmp);
Expand All @@ -650,6 +695,10 @@ void LoweringPreparePass::runOnOp(Operation *op) {
} else if (auto globalDtor = fnOp.getGlobalDtorAttr()) {
globalDtorList.push_back(globalDtor);
}
} else if (auto fmodOp = dyn_cast<FModOp>(op)) {
lowerFModOp(fmodOp);
} else if (auto powOp = dyn_cast<PowOp>(op)) {
lowerPowOp(powOp);
}
}

Expand All @@ -663,8 +712,8 @@ void LoweringPreparePass::runOnOperation() {
SmallVector<Operation *> opsToTransform;
op->walk([&](Operation *op) {
if (isa<CmpThreeWayOp, VAArgOp, GlobalOp, DynamicCastOp, StdFindOp,
IterEndOp, IterBeginOp, ArrayCtor, ArrayDtor, mlir::cir::FuncOp>(
op))
IterEndOp, IterBeginOp, ArrayCtor, ArrayDtor, mlir::cir::FuncOp,
FModOp, PowOp>(op))
opsToTransform.push_back(op);
});

Expand Down
28 changes: 27 additions & 1 deletion clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3163,6 +3163,31 @@ class CIRCmpThreeWayOpLowering
}
};

/// Generic conversion pattern that rewrites a binary fp2fp CIR operation
/// `CIROp` into the LLVM dialect operation `LLVMOp`.
///
/// The result type of the CIR op is run through the pattern's type converter,
/// and the replacement op is built from that type plus the already-converted
/// `lhs` and `rhs` operands taken from the adaptor. The pattern always
/// reports success.
template <typename CIROp, typename LLVMOp>
class CIRBinaryFPToFPBuiltinOpLowering
    : public mlir::OpConversionPattern<CIROp> {
  using Base = mlir::OpConversionPattern<CIROp>;

public:
  // Inherit the constructors of the conversion-pattern base class.
  using Base::Base;

  mlir::LogicalResult
  matchAndRewrite(CIROp op, typename Base::OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::Type llvmResultTy =
        this->getTypeConverter()->convertType(op.getType());
    mlir::Value lhs = adaptor.getLhs();
    mlir::Value rhs = adaptor.getRhs();

    rewriter.replaceOpWithNewOp<LLVMOp>(op, llvmResultTy, lhs, rhs);
    return mlir::success();
  }
};

// Concrete lowerings built from CIRBinaryFPToFPBuiltinOpLowering. Each alias
// pairs a CIR operation with the LLVM dialect operation that replaces it.

// mlir::cir::CopysignOp -> mlir::LLVM::CopySignOp
using CIRCopysignOpLowering =
CIRBinaryFPToFPBuiltinOpLowering<mlir::cir::CopysignOp,
mlir::LLVM::CopySignOp>;
// mlir::cir::FMaxOp -> mlir::LLVM::MaxNumOp
using CIRFMaxOpLowering =
CIRBinaryFPToFPBuiltinOpLowering<mlir::cir::FMaxOp, mlir::LLVM::MaxNumOp>;
// mlir::cir::FMinOp -> mlir::LLVM::MinNumOp
using CIRFMinOpLowering =
CIRBinaryFPToFPBuiltinOpLowering<mlir::cir::FMinOp, mlir::LLVM::MinNumOp>;

void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
mlir::TypeConverter &converter) {
patterns.add<CIRReturnLowering>(patterns.getContext());
Expand All @@ -3187,7 +3212,8 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
CIRStackRestoreLowering, CIRUnreachableLowering, CIRTrapLowering,
CIRInlineAsmOpLowering, CIRSetBitfieldLowering, CIRGetBitfieldLowering,
CIRPrefetchLowering, CIRObjSizeOpLowering, CIRIsConstantOpLowering,
CIRCmpThreeWayOpLowering>(converter, patterns.getContext());
CIRCmpThreeWayOpLowering, CIRCopysignOpLowering, CIRFMaxOpLowering,
CIRFMinOpLowering>(converter, patterns.getContext());
}

namespace {
Expand Down
Loading

0 comments on commit 01dd7d8

Please sign in to comment.