Skip to content

Commit

Permalink
[CIR] Lower certain cir.cmp3way operations to LLVM intrinsics (#556)
Browse files Browse the repository at this point in the history
LLVM recently added two families of intrinsics named `llvm.scmp.*` and
`llvm.ucmp.*` that generate potentially better code for three-way
comparison operations. This PR lowers certain `cir.cmp3way` operations
to these intrinsics.

Not all `cir.cmp3way` operations can be lowered to these intrinsics. The
qualifying conditions are: 1) the comparison is between two integers,
and 2) the comparison produces a strong ordering. `cir.cmp3way`
operations that are not qualified are not affected by this PR.

Qualifying `cir.cmp3way` operations may still need some canonicalization
work before lowering. The "canonicalized" form of a qualifying three-way
comparison operation yields -1 for lt, 0 for eq, and 1 for gt. This PR
converts those non-canonicalized but qualifying `cir.cmp3way` operations
to their canonical forms in the LLVM lowering prepare pass.

This PR addresses #514 .
  • Loading branch information
Lancern authored and lanza committed Apr 29, 2024
1 parent ad5e29f commit 12769bd
Show file tree
Hide file tree
Showing 8 changed files with 265 additions and 33 deletions.
5 changes: 5 additions & 0 deletions clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
mlir::cir::UnaryOpKind::Not, value);
}

mlir::cir::CmpOp createCompare(mlir::Location loc, mlir::cir::CmpOpKind kind,
mlir::Value lhs, mlir::Value rhs) {
return create<mlir::cir::CmpOp>(loc, getBoolTy(), kind, lhs, rhs);
}

mlir::Value createBinop(mlir::Value lhs, mlir::cir::BinOpKind kind,
const llvm::APInt &rhs) {
return create<mlir::cir::BinOp>(
Expand Down
12 changes: 12 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1328,6 +1328,18 @@ def CmpThreeWayOp : CIR_Op<"cmp3way", [Pure, SameTypeOperands]> {
}];

let hasVerifier = 0;

let extraClassDeclaration = [{
/// Determine whether this three-way comparison produces a strong ordering.
bool isStrongOrdering() {
return getInfo().getOrdering() == mlir::cir::CmpOrdering::Strong;
}

/// Determine whether this three-way comparison compares integral operands.
bool isIntegralComparison() {
return getLhs().getType().isa<mlir::cir::IntType>();
}
}];
}

//===----------------------------------------------------------------------===//
Expand Down
5 changes: 0 additions & 5 deletions clang/lib/CIR/CodeGen/CIRGenBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -594,11 +594,6 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
return create<mlir::cir::ContinueOp>(loc);
}

mlir::cir::CmpOp createCompare(mlir::Location loc, mlir::cir::CmpOpKind kind,
mlir::Value lhs, mlir::Value rhs) {
return create<mlir::cir::CmpOp>(loc, getBoolTy(), kind, lhs, rhs);
}

mlir::cir::MemCpyOp createMemCpy(mlir::Location loc, mlir::Value dst,
mlir::Value src, mlir::Value len) {
return create<mlir::cir::MemCpyOp>(loc, dst, src, len);
Expand Down
72 changes: 72 additions & 0 deletions clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,10 +265,82 @@ FuncOp LoweringPreparePass::buildCXXGlobalVarDeclInitFunc(GlobalOp op) {
return f;
}

static void canonicalizeIntrinsicThreeWayCmp(CIRBaseBuilderTy &builder,
CmpThreeWayOp op) {
auto loc = op->getLoc();
auto cmpInfo = op.getInfo();

if (cmpInfo.getLt() == -1 && cmpInfo.getEq() == 0 && cmpInfo.getGt() == 1) {
// The comparison is already in canonicalized form.
return;
}

auto canonicalizedCmpInfo =
mlir::cir::CmpThreeWayInfoAttr::get(builder.getContext(), -1, 0, 1);
mlir::Value result =
builder
.create<mlir::cir::CmpThreeWayOp>(loc, op.getType(), op.getLhs(),
op.getRhs(), canonicalizedCmpInfo)
.getResult();

auto compareAndYield = [&](mlir::Value input, int64_t test,
int64_t yield) -> mlir::Value {
// Create a conditional branch that tests whether `input` is equal to
// `test`. If `input` is equal to `test`, yield `yield`. Otherwise, yield
// `input` as is.
auto testValue = builder.getConstant(
loc, mlir::cir::IntAttr::get(input.getType(), test));
auto yieldValue = builder.getConstant(
loc, mlir::cir::IntAttr::get(input.getType(), yield));
auto eqToTest =
builder.createCompare(loc, mlir::cir::CmpOpKind::eq, input, testValue);
return builder
.create<mlir::cir::TernaryOp>(
loc, eqToTest,
[&](OpBuilder &, Location) {
builder.create<mlir::cir::YieldOp>(loc,
mlir::ValueRange{yieldValue});
},
[&](OpBuilder &, Location) {
builder.create<mlir::cir::YieldOp>(loc, mlir::ValueRange{input});
})
->getResult(0);
};

if (cmpInfo.getLt() != -1)
result = compareAndYield(result, -1, cmpInfo.getLt());

if (cmpInfo.getEq() != 0)
result = compareAndYield(result, 0, cmpInfo.getEq());

if (cmpInfo.getGt() != 1)
result = compareAndYield(result, 1, cmpInfo.getGt());

op.replaceAllUsesWith(result);
op.erase();
}

void LoweringPreparePass::lowerThreeWayCmpOp(CmpThreeWayOp op) {
CIRBaseBuilderTy builder(getContext());
builder.setInsertionPointAfter(op);

if (op.isIntegralComparison() && op.isStrongOrdering()) {
// For three-way comparisons on integral operands that produce strong
// ordering, we can generate potentially better code with the `llvm.scmp.*`
// and `llvm.ucmp.*` intrinsics. Thus we don't replace these comparisons
// here. They will be lowered directly to LLVMIR during the LLVM lowering
// pass.
//
// But we still need to take a step here. `llvm.scmp.*` and `llvm.ucmp.*`
// returns -1, 0, or 1 to represent lt, eq, and gt, which are the
// "canonicalized" result values of three-way comparisons. However,
// `cir.cmp3way` may not produce canonicalized result. We need to
// canonicalize the comparison if necessary. This is what we're doing in
// this special branch.
canonicalizeIntrinsicThreeWayCmp(builder, op);
return;
}

auto loc = op->getLoc();
auto cmpInfo = op.getInfo();

Expand Down
89 changes: 80 additions & 9 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/Support/Casting.h"
Expand Down Expand Up @@ -1984,6 +1985,16 @@ class CIRCmpOpLowering : public mlir::OpConversionPattern<mlir::cir::CmpOp> {
}
};

static mlir::LLVM::CallIntrinsicOp
createCallLLVMIntrinsicOp(mlir::ConversionPatternRewriter &rewriter,
mlir::Location loc, const llvm::Twine &intrinsicName,
mlir::Type resultTy, mlir::ValueRange operands) {
auto intrinsicNameAttr =
mlir::StringAttr::get(rewriter.getContext(), intrinsicName);
return rewriter.create<mlir::LLVM::CallIntrinsicOp>(
loc, resultTy, intrinsicNameAttr, operands);
}

static mlir::Value createLLVMBitOp(mlir::Location loc,
const llvm::Twine &llvmIntrinBaseName,
mlir::Type resultTy, mlir::Value operand,
Expand All @@ -1996,21 +2007,19 @@ static mlir::Value createLLVMBitOp(mlir::Location loc,
llvmIntrinBaseName.concat(".i")
.concat(std::to_string(operandIntTy.getWidth()))
.str();
auto llvmIntrinNameAttr =
mlir::StringAttr::get(rewriter.getContext(), llvmIntrinName);

// Note that LLVM intrinsic calls to bit intrinsics have the same type as the
// operand.
mlir::LLVM::CallIntrinsicOp op;
if (poisonZeroInputFlag.has_value()) {
auto poisonZeroInputValue = rewriter.create<mlir::LLVM::ConstantOp>(
loc, rewriter.getI1Type(), static_cast<int64_t>(*poisonZeroInputFlag));
op = rewriter.create<mlir::LLVM::CallIntrinsicOp>(
loc, operand.getType(), llvmIntrinNameAttr,
mlir::ValueRange{operand, poisonZeroInputValue});
op = createCallLLVMIntrinsicOp(rewriter, loc, llvmIntrinName,
operand.getType(),
{operand, poisonZeroInputValue});
} else {
op = rewriter.create<mlir::LLVM::CallIntrinsicOp>(
loc, operand.getType(), llvmIntrinNameAttr, operand);
op = createCallLLVMIntrinsicOp(rewriter, loc, llvmIntrinName,
operand.getType(), operand);
}

mlir::Value result = op->getResult(0);
Expand Down Expand Up @@ -2864,6 +2873,68 @@ class CIRIsConstantOpLowering
}
};

class CIRCmpThreeWayOpLowering
: public mlir::OpConversionPattern<mlir::cir::CmpThreeWayOp> {
public:
using mlir::OpConversionPattern<
mlir::cir::CmpThreeWayOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(mlir::cir::CmpThreeWayOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
if (!op.isIntegralComparison() || !op.isStrongOrdering()) {
op.emitError() << "unsupported three-way comparison type";
return mlir::failure();
}

auto cmpInfo = op.getInfo();
assert(cmpInfo.getLt() == -1 && cmpInfo.getEq() == 0 &&
cmpInfo.getGt() == 1);

auto operandTy = op.getLhs().getType().cast<mlir::cir::IntType>();
auto resultTy = op.getType();
auto llvmIntrinsicName = getLLVMIntrinsicName(
operandTy.isSigned(), operandTy.getWidth(), resultTy.getWidth());

rewriter.setInsertionPoint(op);

auto llvmLhs = adaptor.getLhs();
auto llvmRhs = adaptor.getRhs();
auto llvmResultTy = getTypeConverter()->convertType(resultTy);
auto callIntrinsicOp =
createCallLLVMIntrinsicOp(rewriter, op.getLoc(), llvmIntrinsicName,
llvmResultTy, {llvmLhs, llvmRhs});

rewriter.replaceOp(op, callIntrinsicOp);
return mlir::success();
}

private:
static std::string getLLVMIntrinsicName(bool signedCmp, unsigned operandWidth,
unsigned resultWidth) {
// The intrinsic's name takes the form:
// `llvm.<scmp|ucmp>.i<resultWidth>.i<operandWidth>`

std::string result = "llvm.";

if (signedCmp)
result.append("scmp.");
else
result.append("ucmp.");

// Result type part.
result.push_back('i');
result.append(std::to_string(resultWidth));
result.push_back('.');

// Operand type part.
result.push_back('i');
result.append(std::to_string(operandWidth));

return result;
}
};

void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
mlir::TypeConverter &converter) {
patterns.add<CIRReturnLowering>(patterns.getContext());
Expand All @@ -2886,8 +2957,8 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
CIRVectorShuffleVecLowering, CIRStackSaveLowering,
CIRStackRestoreLowering, CIRUnreachableLowering, CIRTrapLowering,
CIRInlineAsmOpLowering, CIRSetBitfieldLowering, CIRGetBitfieldLowering,
CIRPrefetchLowering, CIRObjSizeOpLowering, CIRIsConstantOpLowering>(
converter, patterns.getContext());
CIRPrefetchLowering, CIRObjSizeOpLowering, CIRIsConstantOpLowering,
CIRCmpThreeWayOpLowering>(converter, patterns.getContext());
}

namespace {
Expand Down
17 changes: 17 additions & 0 deletions clang/test/CIR/CodeGen/Inputs/std-compare.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,21 @@
namespace std {
inline namespace __1 {

#ifdef NON_CANONICAL_CMP_RESULTS

// exposition only
enum class _EqResult : unsigned char {
__equal = 2,
__equiv = __equal,
};

enum class _OrdResult : signed char {
__less = 1,
__greater = 3
};

#else

// exposition only
enum class _EqResult : unsigned char {
__equal = 0,
Expand All @@ -15,6 +30,8 @@ enum class _OrdResult : signed char {
__greater = 1
};

#endif

enum class _NCmpResult : signed char {
__unordered = -127
};
Expand Down
58 changes: 39 additions & 19 deletions clang/test/CIR/CodeGen/three-way-comparison.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++20 -fclangir -emit-cir -mmlir --mlir-print-ir-before=cir-lowering-prepare %s -o %t.cir 2>&1 | FileCheck %s -check-prefix=BEFORE
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++20 -fclangir -emit-cir -mmlir --mlir-print-ir-after=cir-lowering-prepare %s -o %t.cir 2>&1 | FileCheck %s -check-prefix=AFTER
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++20 -DNON_CANONICAL_CMP_RESULTS -fclangir -emit-cir -mmlir --mlir-print-ir-before=cir-lowering-prepare %s -o %t.cir 2>&1 | FileCheck %s -check-prefix=NONCANONICAL-BEFORE
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++20 -DNON_CANONICAL_CMP_RESULTS -fclangir -emit-cir -mmlir --mlir-print-ir-after=cir-lowering-prepare %s -o %t.cir 2>&1 | FileCheck %s -check-prefix=NONCANONICAL-AFTER

#include "Inputs/std-compare.h"

Expand All @@ -16,25 +18,43 @@ auto three_way_strong(int x, int y) {
// BEFORE: %{{.+}} = cir.cmp3way(%{{.+}} : !s32i, %{{.+}}, #cmp3way_info_strong_ltn1eq0gt1_) : !s8i
// BEFORE: }

// AFTER: cir.func @_Z16three_way_strongii
// AFTER: %[[#LHS:]] = cir.load %{{.+}} : cir.ptr <!s32i>, !s32i
// AFTER-NEXT: %[[#RHS:]] = cir.load %{{.+}} : cir.ptr <!s32i>, !s32i
// AFTER-NEXT: %[[#LT:]] = cir.const(#cir.int<-1> : !s8i) : !s8i
// AFTER-NEXT: %[[#EQ:]] = cir.const(#cir.int<0> : !s8i) : !s8i
// AFTER-NEXT: %[[#GT:]] = cir.const(#cir.int<1> : !s8i) : !s8i
// AFTER-NEXT: %[[#CMP_LT:]] = cir.cmp(lt, %[[#LHS]], %[[#RHS]]) : !s32i, !cir.bool
// AFTER-NEXT: %[[#CMP_EQ:]] = cir.cmp(eq, %[[#LHS]], %[[#RHS]]) : !s32i, !cir.bool
// AFTER-NEXT: %[[#CMP_EQ_RES:]] = cir.ternary(%[[#CMP_EQ]], true {
// AFTER-NEXT: cir.yield %[[#EQ]] : !s8i
// AFTER-NEXT: }, false {
// AFTER-NEXT: cir.yield %[[#GT]] : !s8i
// AFTER-NEXT: }) : (!cir.bool) -> !s8i
// AFTER-NEXT: %{{.+}} = cir.ternary(%[[#CMP_LT]], true {
// AFTER-NEXT: cir.yield %[[#LT]] : !s8i
// AFTER-NEXT: }, false {
// AFTER-NEXT: cir.yield %[[#CMP_EQ_RES]] : !s8i
// AFTER-NEXT: }) : (!cir.bool) -> !s8i
// AFTER: }
// AFTER: cir.func @_Z16three_way_strongii
// AFTER: %{{.+}} = cir.cmp3way(%{{.+}} : !s32i, %{{.+}}, #cmp3way_info_strong_ltn1eq0gt1_) : !s8i
// AFTER: }

// NONCANONICAL-BEFORE: #cmp3way_info_strong_lt1eq2gt3_ = #cir.cmp3way_info<strong, lt = 1, eq = 2, gt = 3>
// NONCANONICAL-BEFORE: cir.func @_Z16three_way_strongii
// NONCANONICAL-BEFORE: %{{.+}} = cir.cmp3way(%{{.+}} : !s32i, %{{.+}}, #cmp3way_info_strong_lt1eq2gt3_) : !s8i
// NONCANONICAL-BEFORE: }

// NONCANONICAL-AFTER: #cmp3way_info_strong_ltn1eq0gt1_ = #cir.cmp3way_info<strong, lt = -1, eq = 0, gt = 1>
// NONCANONICAL-AFTER: cir.func @_Z16three_way_strongii
// NONCANONICAL-AFTER: %[[#CMP3WAY_RESULT:]] = cir.cmp3way(%{{.+}} : !s32i, %{{.+}}, #cmp3way_info_strong_ltn1eq0gt1_) : !s8i
// NONCANONICAL-AFTER-NEXT: %[[#NEGONE:]] = cir.const(#cir.int<-1> : !s8i) : !s8i
// NONCANONICAL-AFTER-NEXT: %[[#ONE:]] = cir.const(#cir.int<1> : !s8i) : !s8i
// NONCANONICAL-AFTER-NEXT: %[[#CMP_TO_NEGONE:]] = cir.cmp(eq, %[[#CMP3WAY_RESULT]], %[[#NEGONE]]) : !s8i, !cir.bool
// NONCANONICAL-AFTER-NEXT: %[[#A:]] = cir.ternary(%[[#CMP_TO_NEGONE]], true {
// NONCANONICAL-AFTER-NEXT: cir.yield %[[#ONE]] : !s8i
// NONCANONICAL-AFTER-NEXT: }, false {
// NONCANONICAL-AFTER-NEXT: cir.yield %[[#CMP3WAY_RESULT]] : !s8i
// NONCANONICAL-AFTER-NEXT: }) : (!cir.bool) -> !s8i
// NONCANONICAL-AFTER-NEXT: %[[#ZERO:]] = cir.const(#cir.int<0> : !s8i) : !s8i
// NONCANONICAL-AFTER-NEXT: %[[#TWO:]] = cir.const(#cir.int<2> : !s8i) : !s8i
// NONCANONICAL-AFTER-NEXT: %[[#CMP_TO_ZERO:]] = cir.cmp(eq, %[[#A]], %[[#ZERO]]) : !s8i, !cir.bool
// NONCANONICAL-AFTER-NEXT: %[[#B:]] = cir.ternary(%[[#CMP_TO_ZERO]], true {
// NONCANONICAL-AFTER-NEXT: cir.yield %[[#TWO]] : !s8i
// NONCANONICAL-AFTER-NEXT: }, false {
// NONCANONICAL-AFTER-NEXT: cir.yield %[[#A]] : !s8i
// NONCANONICAL-AFTER-NEXT: }) : (!cir.bool) -> !s8i
// NONCANONICAL-AFTER-NEXT: %[[#ONE2:]] = cir.const(#cir.int<1> : !s8i) : !s8i
// NONCANONICAL-AFTER-NEXT: %[[#THREE:]] = cir.const(#cir.int<3> : !s8i) : !s8i
// NONCANONICAL-AFTER-NEXT: %[[#CMP_TO_ONE:]] = cir.cmp(eq, %[[#B]], %[[#ONE2]]) : !s8i, !cir.bool
// NONCANONICAL-AFTER-NEXT: %{{.+}} = cir.ternary(%[[#CMP_TO_ONE]], true {
// NONCANONICAL-AFTER-NEXT: cir.yield %[[#THREE]] : !s8i
// NONCANONICAL-AFTER-NEXT: }, false {
// NONCANONICAL-AFTER-NEXT: cir.yield %[[#B]] : !s8i
// NONCANONICAL-AFTER-NEXT: }) : (!cir.bool) -> !s8i
// NONCANONICAL-AFTER: }

auto three_way_weak(float x, float y) {
return x <=> y;
Expand Down
Loading

0 comments on commit 12769bd

Please sign in to comment.