Skip to content

Commit 13ec913

Browse files
committed
[InstCombine] Recognize ((x * y) s/ x) !=/== y as a signed multiplication overflow check (PR48769)
We already had support for its unsigned variant, so simply extend it to also handle the signed variant. Fixes https://bugs.llvm.org/show_bug.cgi?id=48769
1 parent 632eb20 commit 13ec913

File tree

4 files changed

+71
-66
lines changed

4 files changed

+71
-66
lines changed

llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3672,19 +3672,22 @@ foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ,
36723672

36733673
/// Fold
36743674
/// (-1 u/ x) u< y
3675-
/// ((x * y) u/ x) != y
3675+
/// ((x * y) ?/ x) != y
36763676
/// to
3677-
/// @llvm.umul.with.overflow(x, y) plus extraction of overflow bit
3677+
/// @llvm.?mul.with.overflow(x, y) plus extraction of overflow bit
36783678
/// Note that the comparison is commutative, while inverted (u>=, ==) predicate
36793679
/// will mean that we are looking for the opposite answer.
3680-
Value *InstCombinerImpl::foldUnsignedMultiplicationOverflowCheck(ICmpInst &I) {
3680+
Value *InstCombinerImpl::foldMultiplicationOverflowCheck(ICmpInst &I) {
36813681
ICmpInst::Predicate Pred;
36823682
Value *X, *Y;
36833683
Instruction *Mul;
3684+
Instruction *Div;
36843685
bool NeedNegation;
36853686
// Look for: (-1 u/ x) u</u>= y
36863687
if (!I.isEquality() &&
3687-
match(&I, m_c_ICmp(Pred, m_OneUse(m_UDiv(m_AllOnes(), m_Value(X))),
3688+
match(&I, m_c_ICmp(Pred,
3689+
m_CombineAnd(m_OneUse(m_UDiv(m_AllOnes(), m_Value(X))),
3690+
m_Instruction(Div)),
36883691
m_Value(Y)))) {
36893692
Mul = nullptr;
36903693

@@ -3699,13 +3702,16 @@ Value *InstCombinerImpl::foldUnsignedMultiplicationOverflowCheck(ICmpInst &I) {
36993702
default:
37003703
return nullptr; // Wrong predicate.
37013704
}
3702-
} else // Look for: ((x * y) u/ x) !=/== y
3705+
} else // Look for: ((x * y) / x) !=/== y
37033706
if (I.isEquality() &&
3704-
match(&I, m_c_ICmp(Pred, m_Value(Y),
3705-
m_OneUse(m_UDiv(m_CombineAnd(m_c_Mul(m_Deferred(Y),
3707+
match(&I,
3708+
m_c_ICmp(Pred, m_Value(Y),
3709+
m_CombineAnd(
3710+
m_OneUse(m_IDiv(m_CombineAnd(m_c_Mul(m_Deferred(Y),
37063711
m_Value(X)),
37073712
m_Instruction(Mul)),
3708-
m_Deferred(X)))))) {
3713+
m_Deferred(X))),
3714+
m_Instruction(Div))))) {
37093715
NeedNegation = Pred == ICmpInst::Predicate::ICMP_EQ;
37103716
} else
37113717
return nullptr;
@@ -3717,19 +3723,22 @@ Value *InstCombinerImpl::foldUnsignedMultiplicationOverflowCheck(ICmpInst &I) {
37173723
if (MulHadOtherUses)
37183724
Builder.SetInsertPoint(Mul);
37193725

3720-
Function *F = Intrinsic::getDeclaration(
3721-
I.getModule(), Intrinsic::umul_with_overflow, X->getType());
3722-
CallInst *Call = Builder.CreateCall(F, {X, Y}, "umul");
3726+
Function *F = Intrinsic::getDeclaration(I.getModule(),
3727+
Div->getOpcode() == Instruction::UDiv
3728+
? Intrinsic::umul_with_overflow
3729+
: Intrinsic::smul_with_overflow,
3730+
X->getType());
3731+
CallInst *Call = Builder.CreateCall(F, {X, Y}, "mul");
37233732

37243733
// If the multiplication was used elsewhere, to ensure that we don't leave
37253734
// "duplicate" instructions, replace uses of that original multiplication
37263735
// with the multiplication result from the with.overflow intrinsic.
37273736
if (MulHadOtherUses)
3728-
replaceInstUsesWith(*Mul, Builder.CreateExtractValue(Call, 0, "umul.val"));
3737+
replaceInstUsesWith(*Mul, Builder.CreateExtractValue(Call, 0, "mul.val"));
37293738

3730-
Value *Res = Builder.CreateExtractValue(Call, 1, "umul.ov");
3739+
Value *Res = Builder.CreateExtractValue(Call, 1, "mul.ov");
37313740
if (NeedNegation) // This technically increases instruction count.
3732-
Res = Builder.CreateNot(Res, "umul.not.ov");
3741+
Res = Builder.CreateNot(Res, "mul.not.ov");
37333742

37343743
// If we replaced the mul, erase it. Do this after all uses of Builder,
37353744
// as the mul is used as insertion point.
@@ -4126,7 +4135,7 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I,
41264135
}
41274136
}
41284137

4129-
if (Value *V = foldUnsignedMultiplicationOverflowCheck(I))
4138+
if (Value *V = foldMultiplicationOverflowCheck(I))
41304139
return replaceInstUsesWith(I, V);
41314140

41324141
if (Value *V = foldICmpWithLowBitMaskedVal(I, Builder))

llvm/lib/Transforms/InstCombine/InstCombineInternal.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -656,7 +656,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
656656
Instruction *foldSignBitTest(ICmpInst &I);
657657
Instruction *foldICmpWithZero(ICmpInst &Cmp);
658658

659-
Value *foldUnsignedMultiplicationOverflowCheck(ICmpInst &Cmp);
659+
Value *foldMultiplicationOverflowCheck(ICmpInst &Cmp);
660660

661661
Instruction *foldICmpSelectConstant(ICmpInst &Cmp, SelectInst *Select,
662662
ConstantInt *C);

llvm/test/Transforms/InstCombine/signed-mul-lack-of-overflow-check-via-mul-sdiv.ll

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88

99
define i1 @t0_basic(i8 %x, i8 %y) {
1010
; CHECK-LABEL: @t0_basic(
11-
; CHECK-NEXT: [[T0:%.*]] = mul i8 [[X:%.*]], [[Y:%.*]]
12-
; CHECK-NEXT: [[T1:%.*]] = sdiv i8 [[T0]], [[X]]
13-
; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[T1]], [[Y]]
14-
; CHECK-NEXT: ret i1 [[R]]
11+
; CHECK-NEXT: [[MUL:%.*]] = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 [[X:%.*]], i8 [[Y:%.*]])
12+
; CHECK-NEXT: [[MUL_OV:%.*]] = extractvalue { i8, i1 } [[MUL]], 1
13+
; CHECK-NEXT: [[MUL_NOT_OV:%.*]] = xor i1 [[MUL_OV]], true
14+
; CHECK-NEXT: ret i1 [[MUL_NOT_OV]]
1515
;
1616
%t0 = mul i8 %x, %y
1717
%t1 = sdiv i8 %t0, %x
@@ -21,10 +21,10 @@ define i1 @t0_basic(i8 %x, i8 %y) {
2121

2222
define <2 x i1> @t1_vec(<2 x i8> %x, <2 x i8> %y) {
2323
; CHECK-LABEL: @t1_vec(
24-
; CHECK-NEXT: [[T0:%.*]] = mul <2 x i8> [[X:%.*]], [[Y:%.*]]
25-
; CHECK-NEXT: [[T1:%.*]] = sdiv <2 x i8> [[T0]], [[X]]
26-
; CHECK-NEXT: [[R:%.*]] = icmp eq <2 x i8> [[T1]], [[Y]]
27-
; CHECK-NEXT: ret <2 x i1> [[R]]
24+
; CHECK-NEXT: [[MUL:%.*]] = call { <2 x i8>, <2 x i1> } @llvm.smul.with.overflow.v2i8(<2 x i8> [[X:%.*]], <2 x i8> [[Y:%.*]])
25+
; CHECK-NEXT: [[MUL_OV:%.*]] = extractvalue { <2 x i8>, <2 x i1> } [[MUL]], 1
26+
; CHECK-NEXT: [[MUL_NOT_OV:%.*]] = xor <2 x i1> [[MUL_OV]], <i1 true, i1 true>
27+
; CHECK-NEXT: ret <2 x i1> [[MUL_NOT_OV]]
2828
;
2929
%t0 = mul <2 x i8> %x, %y
3030
%t1 = sdiv <2 x i8> %t0, %x
@@ -37,10 +37,10 @@ declare i8 @gen8()
3737
define i1 @t2_commutative(i8 %x) {
3838
; CHECK-LABEL: @t2_commutative(
3939
; CHECK-NEXT: [[Y:%.*]] = call i8 @gen8()
40-
; CHECK-NEXT: [[T0:%.*]] = mul i8 [[Y]], [[X:%.*]]
41-
; CHECK-NEXT: [[T1:%.*]] = sdiv i8 [[T0]], [[X]]
42-
; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[T1]], [[Y]]
43-
; CHECK-NEXT: ret i1 [[R]]
40+
; CHECK-NEXT: [[MUL:%.*]] = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 [[X:%.*]], i8 [[Y]])
41+
; CHECK-NEXT: [[MUL_OV:%.*]] = extractvalue { i8, i1 } [[MUL]], 1
42+
; CHECK-NEXT: [[MUL_NOT_OV:%.*]] = xor i1 [[MUL_OV]], true
43+
; CHECK-NEXT: ret i1 [[MUL_NOT_OV]]
4444
;
4545
%y = call i8 @gen8()
4646
%t0 = mul i8 %y, %x ; swapped
@@ -52,10 +52,10 @@ define i1 @t2_commutative(i8 %x) {
5252
define i1 @t3_commutative(i8 %x) {
5353
; CHECK-LABEL: @t3_commutative(
5454
; CHECK-NEXT: [[Y:%.*]] = call i8 @gen8()
55-
; CHECK-NEXT: [[T0:%.*]] = mul i8 [[Y]], [[X:%.*]]
56-
; CHECK-NEXT: [[T1:%.*]] = sdiv i8 [[T0]], [[X]]
57-
; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[T1]], [[Y]]
58-
; CHECK-NEXT: ret i1 [[R]]
55+
; CHECK-NEXT: [[MUL:%.*]] = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 [[X:%.*]], i8 [[Y]])
56+
; CHECK-NEXT: [[MUL_OV:%.*]] = extractvalue { i8, i1 } [[MUL]], 1
57+
; CHECK-NEXT: [[MUL_NOT_OV:%.*]] = xor i1 [[MUL_OV]], true
58+
; CHECK-NEXT: ret i1 [[MUL_NOT_OV]]
5959
;
6060
%y = call i8 @gen8()
6161
%t0 = mul i8 %y, %x ; swapped
@@ -67,10 +67,10 @@ define i1 @t3_commutative(i8 %x) {
6767
define i1 @t4_commutative(i8 %x) {
6868
; CHECK-LABEL: @t4_commutative(
6969
; CHECK-NEXT: [[Y:%.*]] = call i8 @gen8()
70-
; CHECK-NEXT: [[T0:%.*]] = mul i8 [[Y]], [[X:%.*]]
71-
; CHECK-NEXT: [[T1:%.*]] = sdiv i8 [[T0]], [[X]]
72-
; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[Y]], [[T1]]
73-
; CHECK-NEXT: ret i1 [[R]]
70+
; CHECK-NEXT: [[MUL:%.*]] = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 [[X:%.*]], i8 [[Y]])
71+
; CHECK-NEXT: [[MUL_OV:%.*]] = extractvalue { i8, i1 } [[MUL]], 1
72+
; CHECK-NEXT: [[MUL_NOT_OV:%.*]] = xor i1 [[MUL_OV]], true
73+
; CHECK-NEXT: ret i1 [[MUL_NOT_OV]]
7474
;
7575
%y = call i8 @gen8()
7676
%t0 = mul i8 %y, %x ; swapped
@@ -85,11 +85,12 @@ declare void @use8(i8)
8585

8686
define i1 @t5_extrause0(i8 %x, i8 %y) {
8787
; CHECK-LABEL: @t5_extrause0(
88-
; CHECK-NEXT: [[T0:%.*]] = mul i8 [[X:%.*]], [[Y:%.*]]
89-
; CHECK-NEXT: call void @use8(i8 [[T0]])
90-
; CHECK-NEXT: [[T1:%.*]] = sdiv i8 [[T0]], [[X]]
91-
; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[T1]], [[Y]]
92-
; CHECK-NEXT: ret i1 [[R]]
88+
; CHECK-NEXT: [[MUL:%.*]] = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 [[X:%.*]], i8 [[Y:%.*]])
89+
; CHECK-NEXT: [[MUL_VAL:%.*]] = extractvalue { i8, i1 } [[MUL]], 0
90+
; CHECK-NEXT: [[MUL_OV:%.*]] = extractvalue { i8, i1 } [[MUL]], 1
91+
; CHECK-NEXT: [[MUL_NOT_OV:%.*]] = xor i1 [[MUL_OV]], true
92+
; CHECK-NEXT: call void @use8(i8 [[MUL_VAL]])
93+
; CHECK-NEXT: ret i1 [[MUL_NOT_OV]]
9394
;
9495
%t0 = mul i8 %x, %y
9596
call void @use8(i8 %t0)

llvm/test/Transforms/InstCombine/signed-mul-overflow-check-via-mul-sdiv.ll

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,9 @@
88

99
define i1 @t0_basic(i8 %x, i8 %y) {
1010
; CHECK-LABEL: @t0_basic(
11-
; CHECK-NEXT: [[T0:%.*]] = mul i8 [[X:%.*]], [[Y:%.*]]
12-
; CHECK-NEXT: [[T1:%.*]] = sdiv i8 [[T0]], [[X]]
13-
; CHECK-NEXT: [[R:%.*]] = icmp ne i8 [[T1]], [[Y]]
14-
; CHECK-NEXT: ret i1 [[R]]
11+
; CHECK-NEXT: [[MUL:%.*]] = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 [[X:%.*]], i8 [[Y:%.*]])
12+
; CHECK-NEXT: [[MUL_OV:%.*]] = extractvalue { i8, i1 } [[MUL]], 1
13+
; CHECK-NEXT: ret i1 [[MUL_OV]]
1514
;
1615
%t0 = mul i8 %x, %y
1716
%t1 = sdiv i8 %t0, %x
@@ -21,10 +20,9 @@ define i1 @t0_basic(i8 %x, i8 %y) {
2120

2221
define <2 x i1> @t1_vec(<2 x i8> %x, <2 x i8> %y) {
2322
; CHECK-LABEL: @t1_vec(
24-
; CHECK-NEXT: [[T0:%.*]] = mul <2 x i8> [[X:%.*]], [[Y:%.*]]
25-
; CHECK-NEXT: [[T1:%.*]] = sdiv <2 x i8> [[T0]], [[X]]
26-
; CHECK-NEXT: [[R:%.*]] = icmp ne <2 x i8> [[T1]], [[Y]]
27-
; CHECK-NEXT: ret <2 x i1> [[R]]
23+
; CHECK-NEXT: [[MUL:%.*]] = call { <2 x i8>, <2 x i1> } @llvm.smul.with.overflow.v2i8(<2 x i8> [[X:%.*]], <2 x i8> [[Y:%.*]])
24+
; CHECK-NEXT: [[MUL_OV:%.*]] = extractvalue { <2 x i8>, <2 x i1> } [[MUL]], 1
25+
; CHECK-NEXT: ret <2 x i1> [[MUL_OV]]
2826
;
2927
%t0 = mul <2 x i8> %x, %y
3028
%t1 = sdiv <2 x i8> %t0, %x
@@ -37,10 +35,9 @@ declare i8 @gen8()
3735
define i1 @t2_commutative(i8 %x) {
3836
; CHECK-LABEL: @t2_commutative(
3937
; CHECK-NEXT: [[Y:%.*]] = call i8 @gen8()
40-
; CHECK-NEXT: [[T0:%.*]] = mul i8 [[Y]], [[X:%.*]]
41-
; CHECK-NEXT: [[T1:%.*]] = sdiv i8 [[T0]], [[X]]
42-
; CHECK-NEXT: [[R:%.*]] = icmp ne i8 [[T1]], [[Y]]
43-
; CHECK-NEXT: ret i1 [[R]]
38+
; CHECK-NEXT: [[MUL:%.*]] = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 [[X:%.*]], i8 [[Y]])
39+
; CHECK-NEXT: [[MUL_OV:%.*]] = extractvalue { i8, i1 } [[MUL]], 1
40+
; CHECK-NEXT: ret i1 [[MUL_OV]]
4441
;
4542
%y = call i8 @gen8()
4643
%t0 = mul i8 %y, %x ; swapped
@@ -52,10 +49,9 @@ define i1 @t2_commutative(i8 %x) {
5249
define i1 @t3_commutative(i8 %x) {
5350
; CHECK-LABEL: @t3_commutative(
5451
; CHECK-NEXT: [[Y:%.*]] = call i8 @gen8()
55-
; CHECK-NEXT: [[T0:%.*]] = mul i8 [[Y]], [[X:%.*]]
56-
; CHECK-NEXT: [[T1:%.*]] = sdiv i8 [[T0]], [[X]]
57-
; CHECK-NEXT: [[R:%.*]] = icmp ne i8 [[T1]], [[Y]]
58-
; CHECK-NEXT: ret i1 [[R]]
52+
; CHECK-NEXT: [[MUL:%.*]] = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 [[X:%.*]], i8 [[Y]])
53+
; CHECK-NEXT: [[MUL_OV:%.*]] = extractvalue { i8, i1 } [[MUL]], 1
54+
; CHECK-NEXT: ret i1 [[MUL_OV]]
5955
;
6056
%y = call i8 @gen8()
6157
%t0 = mul i8 %y, %x ; swapped
@@ -67,10 +63,9 @@ define i1 @t3_commutative(i8 %x) {
6763
define i1 @t4_commutative(i8 %x) {
6864
; CHECK-LABEL: @t4_commutative(
6965
; CHECK-NEXT: [[Y:%.*]] = call i8 @gen8()
70-
; CHECK-NEXT: [[T0:%.*]] = mul i8 [[Y]], [[X:%.*]]
71-
; CHECK-NEXT: [[T1:%.*]] = sdiv i8 [[T0]], [[X]]
72-
; CHECK-NEXT: [[R:%.*]] = icmp ne i8 [[Y]], [[T1]]
73-
; CHECK-NEXT: ret i1 [[R]]
66+
; CHECK-NEXT: [[MUL:%.*]] = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 [[X:%.*]], i8 [[Y]])
67+
; CHECK-NEXT: [[MUL_OV:%.*]] = extractvalue { i8, i1 } [[MUL]], 1
68+
; CHECK-NEXT: ret i1 [[MUL_OV]]
7469
;
7570
%y = call i8 @gen8()
7671
%t0 = mul i8 %y, %x ; swapped
@@ -85,11 +80,11 @@ declare void @use8(i8)
8580

8681
define i1 @t5_extrause0(i8 %x, i8 %y) {
8782
; CHECK-LABEL: @t5_extrause0(
88-
; CHECK-NEXT: [[T0:%.*]] = mul i8 [[X:%.*]], [[Y:%.*]]
89-
; CHECK-NEXT: call void @use8(i8 [[T0]])
90-
; CHECK-NEXT: [[T1:%.*]] = sdiv i8 [[T0]], [[X]]
91-
; CHECK-NEXT: [[R:%.*]] = icmp ne i8 [[T1]], [[Y]]
92-
; CHECK-NEXT: ret i1 [[R]]
83+
; CHECK-NEXT: [[MUL:%.*]] = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 [[X:%.*]], i8 [[Y:%.*]])
84+
; CHECK-NEXT: [[MUL_VAL:%.*]] = extractvalue { i8, i1 } [[MUL]], 0
85+
; CHECK-NEXT: [[MUL_OV:%.*]] = extractvalue { i8, i1 } [[MUL]], 1
86+
; CHECK-NEXT: call void @use8(i8 [[MUL_VAL]])
87+
; CHECK-NEXT: ret i1 [[MUL_OV]]
9388
;
9489
%t0 = mul i8 %x, %y
9590
call void @use8(i8 %t0)

0 commit comments

Comments
 (0)