Skip to content

Commit 214f3fb

Browse files
sitio-coutolanza
authored andcommitted
[CIR][Lowering] Add support for signed comparisons
Updates CIR's CmpOp lowering to use CIR's custom cir::IntType, allowing it to handle signed comparisons. ghstack-source-id: e4709315db1a39853fe978ef9771ab727ad9f9d7 Pull Request resolved: llvm/clangir#106
1 parent f51e982 commit 214f3fb

File tree

4 files changed

+43
-43
lines changed

4 files changed

+43
-43
lines changed

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ class CIRCastOpLowering : public mlir::OpConversionPattern<mlir::cir::CastOp> {
219219
mlir::cir::IntAttr::get(castOp.getSrc().getType(), 0));
220220
rewriter.replaceOpWithNewOp<mlir::cir::CmpOp>(
221221
castOp, mlir::cir::BoolType::get(getContext()),
222-
mlir::cir::CmpOpKind::ne, src, zero);
222+
mlir::cir::CmpOpKind::ne, castOp.getSrc(), zero);
223223
break;
224224
}
225225
case mlir::cir::CastKind::integral: {
@@ -902,8 +902,8 @@ class CIRCmpOpLowering : public mlir::OpConversionPattern<mlir::cir::CmpOp> {
902902
public:
903903
using OpConversionPattern<mlir::cir::CmpOp>::OpConversionPattern;
904904

905-
mlir::LLVM::ICmpPredicate
906-
convertToICmpPredicate(mlir::cir::CmpOpKind kind) const {
905+
mlir::LLVM::ICmpPredicate convertToICmpPredicate(mlir::cir::CmpOpKind kind,
906+
bool isSigned) const {
907907
using CIR = mlir::cir::CmpOpKind;
908908
using LLVMICmp = mlir::LLVM::ICmpPredicate;
909909

@@ -913,13 +913,13 @@ class CIRCmpOpLowering : public mlir::OpConversionPattern<mlir::cir::CmpOp> {
913913
case CIR::ne:
914914
return LLVMICmp::ne;
915915
case CIR::lt:
916-
return LLVMICmp::ult;
916+
return (isSigned ? LLVMICmp::slt : LLVMICmp::ult);
917917
case CIR::le:
918-
return LLVMICmp::ule;
918+
return (isSigned ? LLVMICmp::sle : LLVMICmp::ule);
919919
case CIR::gt:
920-
return LLVMICmp::ugt;
920+
return (isSigned ? LLVMICmp::sgt : LLVMICmp::ugt);
921921
case CIR::ge:
922-
return LLVMICmp::uge;
922+
return (isSigned ? LLVMICmp::sge : LLVMICmp::uge);
923923
}
924924
llvm_unreachable("Unknown CmpOpKind");
925925
}
@@ -949,12 +949,12 @@ class CIRCmpOpLowering : public mlir::OpConversionPattern<mlir::cir::CmpOp> {
949949
mlir::LogicalResult
950950
matchAndRewrite(mlir::cir::CmpOp cmpOp, OpAdaptor adaptor,
951951
mlir::ConversionPatternRewriter &rewriter) const override {
952-
auto type = adaptor.getLhs().getType();
952+
auto type = cmpOp.getLhs().getType();
953953
mlir::Value llResult;
954954

955955
// Lower to LLVM comparison op.
956-
if (auto intTy = type.dyn_cast<mlir::IntegerType>()) {
957-
auto kind = convertToICmpPredicate(cmpOp.getKind());
956+
if (auto intTy = type.dyn_cast<mlir::cir::IntType>()) {
957+
auto kind = convertToICmpPredicate(cmpOp.getKind(), intTy.isSigned());
958958
llResult = rewriter.create<mlir::LLVM::ICmpOp>(
959959
cmpOp.getLoc(), kind, adaptor.getLhs(), adaptor.getRhs());
960960
} else if (type.isa<mlir::FloatType>()) {

clang/test/CIR/Lowering/cmp.cir

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,31 @@
11
// RUN: cir-tool %s -cir-to-llvm -o - | FileCheck %s -check-prefix=MLIR
22
// RUN: cir-tool %s -cir-to-llvm -o - | mlir-translate -mlir-to-llvmir | FileCheck %s -check-prefix=LLVM
3-
3+
!s32i = !cir.int<s, 32>
44
module {
55
cir.func @foo() {
6-
%0 = cir.alloca i32, cir.ptr <i32>, ["a"] {alignment = 4 : i64}
7-
%1 = cir.alloca i32, cir.ptr <i32>, ["b"] {alignment = 4 : i64}
6+
%0 = cir.alloca !s32i, cir.ptr <!s32i>, ["a"] {alignment = 4 : i64}
7+
%1 = cir.alloca !s32i, cir.ptr <!s32i>, ["b"] {alignment = 4 : i64}
88
%2 = cir.alloca f32, cir.ptr <f32>, ["c"] {alignment = 4 : i64}
99
%3 = cir.alloca f32, cir.ptr <f32>, ["d"] {alignment = 4 : i64}
1010
%4 = cir.alloca !cir.bool, cir.ptr <!cir.bool>, ["e"] {alignment = 1 : i64}
11-
%5 = cir.load %0 : cir.ptr <i32>, i32
12-
%6 = cir.load %1 : cir.ptr <i32>, i32
13-
%7 = cir.cmp(gt, %5, %6) : i32, !cir.bool
14-
%8 = cir.load %0 : cir.ptr <i32>, i32
15-
%9 = cir.load %1 : cir.ptr <i32>, i32
16-
%10 = cir.cmp(eq, %8, %9) : i32, !cir.bool
17-
%11 = cir.load %0 : cir.ptr <i32>, i32
18-
%12 = cir.load %1 : cir.ptr <i32>, i32
19-
%13 = cir.cmp(lt, %11, %12) : i32, !cir.bool
20-
%14 = cir.load %0 : cir.ptr <i32>, i32
21-
%15 = cir.load %1 : cir.ptr <i32>, i32
22-
%16 = cir.cmp(ge, %14, %15) : i32, !cir.bool
23-
%17 = cir.load %0 : cir.ptr <i32>, i32
24-
%18 = cir.load %1 : cir.ptr <i32>, i32
25-
%19 = cir.cmp(ne, %17, %18) : i32, !cir.bool
26-
%20 = cir.load %0 : cir.ptr <i32>, i32
27-
%21 = cir.load %1 : cir.ptr <i32>, i32
28-
%22 = cir.cmp(le, %20, %21) : i32, !cir.bool
11+
%5 = cir.load %0 : cir.ptr <!s32i>, !s32i
12+
%6 = cir.load %1 : cir.ptr <!s32i>, !s32i
13+
%7 = cir.cmp(gt, %5, %6) : !s32i, !cir.bool
14+
%8 = cir.load %0 : cir.ptr <!s32i>, !s32i
15+
%9 = cir.load %1 : cir.ptr <!s32i>, !s32i
16+
%10 = cir.cmp(eq, %8, %9) : !s32i, !cir.bool
17+
%11 = cir.load %0 : cir.ptr <!s32i>, !s32i
18+
%12 = cir.load %1 : cir.ptr <!s32i>, !s32i
19+
%13 = cir.cmp(lt, %11, %12) : !s32i, !cir.bool
20+
%14 = cir.load %0 : cir.ptr <!s32i>, !s32i
21+
%15 = cir.load %1 : cir.ptr <!s32i>, !s32i
22+
%16 = cir.cmp(ge, %14, %15) : !s32i, !cir.bool
23+
%17 = cir.load %0 : cir.ptr <!s32i>, !s32i
24+
%18 = cir.load %1 : cir.ptr <!s32i>, !s32i
25+
%19 = cir.cmp(ne, %17, %18) : !s32i, !cir.bool
26+
%20 = cir.load %0 : cir.ptr <!s32i>, !s32i
27+
%21 = cir.load %1 : cir.ptr <!s32i>, !s32i
28+
%22 = cir.cmp(le, %20, %21) : !s32i, !cir.bool
2929
%23 = cir.load %2 : cir.ptr <f32>, f32
3030
%24 = cir.load %3 : cir.ptr <f32>, f32
3131
%25 = cir.cmp(gt, %23, %24) : f32, !cir.bool
@@ -48,25 +48,25 @@ module {
4848
}
4949
}
5050

51-
// MLIR: = llvm.icmp "ugt"
51+
// MLIR: = llvm.icmp "sgt"
5252
// MLIR: = llvm.icmp "eq"
53-
// MLIR: = llvm.icmp "ult"
54-
// MLIR: = llvm.icmp "uge"
53+
// MLIR: = llvm.icmp "slt"
54+
// MLIR: = llvm.icmp "sge"
5555
// MLIR: = llvm.icmp "ne"
56-
// MLIR: = llvm.icmp "ule"
56+
// MLIR: = llvm.icmp "sle"
5757
// MLIR: = llvm.fcmp "ugt"
5858
// MLIR: = llvm.fcmp "ueq"
5959
// MLIR: = llvm.fcmp "ult"
6060
// MLIR: = llvm.fcmp "uge"
6161
// MLIR: = llvm.fcmp "une"
6262
// MLIR: = llvm.fcmp "ule"
6363

64-
// LLVM: icmp ugt i32
64+
// LLVM: icmp sgt i32
6565
// LLVM: icmp eq i32
66-
// LLVM: icmp ult i32
67-
// LLVM: icmp uge i32
66+
// LLVM: icmp slt i32
67+
// LLVM: icmp sge i32
6868
// LLVM: icmp ne i32
69-
// LLVM: icmp ule i32
69+
// LLVM: icmp sle i32
7070
// LLVM: fcmp ugt float
7171
// LLVM: fcmp ueq float
7272
// LLVM: fcmp ult float

clang/test/CIR/Lowering/dot.cir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ module {
8383
// MLIR-NEXT: ^bb2: // 2 preds: ^bb1, ^bb6
8484
// MLIR-NEXT: %14 = llvm.load %12 : !llvm.ptr
8585
// MLIR-NEXT: %15 = llvm.load %5 : !llvm.ptr
86-
// MLIR-NEXT: %16 = llvm.icmp "ult" %14, %15 : i32
86+
// MLIR-NEXT: %16 = llvm.icmp "slt" %14, %15 : i32
8787
// MLIR-NEXT: %17 = llvm.zext %16 : i1 to i32
8888
// MLIR-NEXT: %18 = llvm.mlir.constant(0 : i32) : i32
8989
// MLIR-NEXT: %19 = llvm.icmp "ne" %17, %18 : i32
@@ -144,7 +144,7 @@ module {
144144
// LLVM-NEXT: 11: ; preds = %24, %9
145145
// LLVM-NEXT: %12 = load i32, ptr %10, align 4
146146
// LLVM-NEXT: %13 = load i32, ptr %6, align 4
147-
// LLVM-NEXT: %14 = icmp ult i32 %12, %13
147+
// LLVM-NEXT: %14 = icmp slt i32 %12, %13
148148
// LLVM-NEXT: %15 = zext i1 %14 to i32
149149
// LLVM-NEXT: %16 = icmp ne i32 %15, 0
150150
// LLVM-NEXT: %17 = zext i1 %16 to i8

clang/test/CIR/Lowering/for.cir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ module {
3939
// MLIR-NEXT: ^bb1: // 2 preds: ^bb0, ^bb5
4040
// MLIR-NEXT: %3 = llvm.load %1 : !llvm.ptr
4141
// MLIR-NEXT: %4 = llvm.mlir.constant(10 : i32) : i32
42-
// MLIR-NEXT: %5 = llvm.icmp "ult" %3, %4 : i32
42+
// MLIR-NEXT: %5 = llvm.icmp "slt" %3, %4 : i32
4343
// MLIR-NEXT: %6 = llvm.zext %5 : i1 to i32
4444
// MLIR-NEXT: %7 = llvm.mlir.constant(0 : i32) : i32
4545
// MLIR-NEXT: %8 = llvm.icmp "ne" %6, %7 : i32
@@ -70,7 +70,7 @@ module {
7070
// LLVM-EMPTY:
7171
// LLVM-NEXT: 2:
7272
// LLVM-NEXT: %3 = load i32, ptr %1, align 4
73-
// LLVM-NEXT: %4 = icmp ult i32 %3, 10
73+
// LLVM-NEXT: %4 = icmp slt i32 %3, 10
7474
// LLVM-NEXT: %5 = zext i1 %4 to i32
7575
// LLVM-NEXT: %6 = icmp ne i32 %5, 0
7676
// LLVM-NEXT: %7 = zext i1 %6 to i8

0 commit comments

Comments
 (0)