Skip to content

Commit 75b0c89

Browse files
authored
[InstCombine][VectorCombine][NFC] Unify uses of lossless inverse cast (#156597)
This patch addresses #155216 (comment). This patch adds a helper function to put the inverse cast on constants, with cast flags preserved(optional). Follow-up patches will add trunc/ext handling on VectorCombine and flags preservation on InstCombine.
1 parent b9f571f commit 75b0c89

File tree

11 files changed

+86
-74
lines changed

11 files changed

+86
-74
lines changed

llvm/include/llvm/Analysis/ConstantFolding.h

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,27 @@ LLVM_ABI bool isMathLibCallNoop(const CallBase *Call,
226226

227227
LLVM_ABI Constant *ReadByteArrayFromGlobal(const GlobalVariable *GV,
228228
uint64_t Offset);
229-
}
229+
230+
struct PreservedCastFlags {
231+
bool NNeg = false;
232+
bool NUW = false;
233+
bool NSW = false;
234+
};
235+
236+
/// Try to cast C to InvC losslessly, satisfying CastOp(InvC) equals C, or
237+
/// CastOp(InvC) is a refined value of undefined C. Will try best to
238+
/// preserve the flags.
239+
LLVM_ABI Constant *getLosslessInvCast(Constant *C, Type *InvCastTo,
240+
unsigned CastOp, const DataLayout &DL,
241+
PreservedCastFlags *Flags = nullptr);
242+
243+
LLVM_ABI Constant *
244+
getLosslessUnsignedTrunc(Constant *C, Type *DestTy, const DataLayout &DL,
245+
PreservedCastFlags *Flags = nullptr);
246+
247+
LLVM_ABI Constant *getLosslessSignedTrunc(Constant *C, Type *DestTy,
248+
const DataLayout &DL,
249+
PreservedCastFlags *Flags = nullptr);
250+
} // namespace llvm
230251

231252
#endif

llvm/lib/Analysis/ConstantFolding.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4608,4 +4608,55 @@ bool llvm::isMathLibCallNoop(const CallBase *Call,
46084608
return false;
46094609
}
46104610

4611+
Constant *llvm::getLosslessInvCast(Constant *C, Type *InvCastTo,
4612+
unsigned CastOp, const DataLayout &DL,
4613+
PreservedCastFlags *Flags) {
4614+
switch (CastOp) {
4615+
case Instruction::BitCast:
4616+
// Bitcast is always lossless.
4617+
return ConstantFoldCastOperand(Instruction::BitCast, C, InvCastTo, DL);
4618+
case Instruction::Trunc: {
4619+
auto *ZExtC = ConstantFoldCastOperand(Instruction::ZExt, C, InvCastTo, DL);
4620+
if (Flags) {
4621+
// Truncation back on ZExt value is always NUW.
4622+
Flags->NUW = true;
4623+
// Test positivity of C.
4624+
auto *SExtC =
4625+
ConstantFoldCastOperand(Instruction::SExt, C, InvCastTo, DL);
4626+
Flags->NSW = ZExtC == SExtC;
4627+
}
4628+
return ZExtC;
4629+
}
4630+
case Instruction::SExt:
4631+
case Instruction::ZExt: {
4632+
auto *InvC = ConstantExpr::getTrunc(C, InvCastTo);
4633+
auto *CastInvC = ConstantFoldCastOperand(CastOp, InvC, C->getType(), DL);
4634+
// Must satisfy CastOp(InvC) == C.
4635+
if (!CastInvC || CastInvC != C)
4636+
return nullptr;
4637+
if (Flags && CastOp == Instruction::ZExt) {
4638+
auto *SExtInvC =
4639+
ConstantFoldCastOperand(Instruction::SExt, InvC, C->getType(), DL);
4640+
// Test positivity of InvC.
4641+
Flags->NNeg = CastInvC == SExtInvC;
4642+
}
4643+
return InvC;
4644+
}
4645+
default:
4646+
return nullptr;
4647+
}
4648+
}
4649+
4650+
Constant *llvm::getLosslessUnsignedTrunc(Constant *C, Type *DestTy,
4651+
const DataLayout &DL,
4652+
PreservedCastFlags *Flags) {
4653+
return getLosslessInvCast(C, DestTy, Instruction::ZExt, DL, Flags);
4654+
}
4655+
4656+
Constant *llvm::getLosslessSignedTrunc(Constant *C, Type *DestTy,
4657+
const DataLayout &DL,
4658+
PreservedCastFlags *Flags) {
4659+
return getLosslessInvCast(C, DestTy, Instruction::SExt, DL, Flags);
4660+
}
4661+
46114662
void TargetFolder::anchor() {}

llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1799,16 +1799,17 @@ static Instruction *foldLogicCastConstant(BinaryOperator &Logic, CastInst *Cast,
17991799
// type may provide more information to later folds, and the smaller logic
18001800
// instruction may be cheaper (particularly in the case of vectors).
18011801
Value *X;
1802+
auto &DL = IC.getDataLayout();
18021803
if (match(Cast, m_OneUse(m_ZExt(m_Value(X))))) {
1803-
if (Constant *TruncC = IC.getLosslessUnsignedTrunc(C, SrcTy)) {
1804+
if (Constant *TruncC = getLosslessUnsignedTrunc(C, SrcTy, DL)) {
18041805
// LogicOpc (zext X), C --> zext (LogicOpc X, C)
18051806
Value *NewOp = IC.Builder.CreateBinOp(LogicOpc, X, TruncC);
18061807
return new ZExtInst(NewOp, DestTy);
18071808
}
18081809
}
18091810

18101811
if (match(Cast, m_OneUse(m_SExtLike(m_Value(X))))) {
1811-
if (Constant *TruncC = IC.getLosslessSignedTrunc(C, SrcTy)) {
1812+
if (Constant *TruncC = getLosslessSignedTrunc(C, SrcTy, DL)) {
18121813
// LogicOpc (sext X), C --> sext (LogicOpc X, C)
18131814
Value *NewOp = IC.Builder.CreateBinOp(LogicOpc, X, TruncC);
18141815
return new SExtInst(NewOp, DestTy);

llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1956,7 +1956,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
19561956
Constant *C;
19571957
if (match(I0, m_ZExt(m_Value(X))) && match(I1, m_Constant(C)) &&
19581958
I0->hasOneUse()) {
1959-
if (Constant *NarrowC = getLosslessUnsignedTrunc(C, X->getType())) {
1959+
if (Constant *NarrowC = getLosslessUnsignedTrunc(C, X->getType(), DL)) {
19601960
Value *NarrowMaxMin = Builder.CreateBinaryIntrinsic(IID, X, NarrowC);
19611961
return CastInst::Create(Instruction::ZExt, NarrowMaxMin, II->getType());
19621962
}
@@ -2006,7 +2006,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
20062006
Constant *C;
20072007
if (match(I0, m_SExt(m_Value(X))) && match(I1, m_Constant(C)) &&
20082008
I0->hasOneUse()) {
2009-
if (Constant *NarrowC = getLosslessSignedTrunc(C, X->getType())) {
2009+
if (Constant *NarrowC = getLosslessSignedTrunc(C, X->getType(), DL)) {
20102010
Value *NarrowMaxMin = Builder.CreateBinaryIntrinsic(IID, X, NarrowC);
20112011
return CastInst::Create(Instruction::SExt, NarrowMaxMin, II->getType());
20122012
}

llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6336,7 +6336,7 @@ Instruction *InstCombinerImpl::foldICmpWithZextOrSext(ICmpInst &ICmp) {
63366336

63376337
// If a lossless truncate is possible...
63386338
Type *SrcTy = CastOp0->getSrcTy();
6339-
Constant *Res = getLosslessTrunc(C, SrcTy, CastOp0->getOpcode());
6339+
Constant *Res = getLosslessInvCast(C, SrcTy, CastOp0->getOpcode(), DL);
63406340
if (Res) {
63416341
if (ICmp.isEquality())
63426342
return new ICmpInst(ICmp.getPredicate(), X, Res);

llvm/lib/Transforms/InstCombine/InstCombineInternal.h

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -222,23 +222,6 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
222222
bool fmulByZeroIsZero(Value *MulVal, FastMathFlags FMF,
223223
const Instruction *CtxI) const;
224224

225-
Constant *getLosslessTrunc(Constant *C, Type *TruncTy, unsigned ExtOp) {
226-
Constant *TruncC = ConstantExpr::getTrunc(C, TruncTy);
227-
Constant *ExtTruncC =
228-
ConstantFoldCastOperand(ExtOp, TruncC, C->getType(), DL);
229-
if (ExtTruncC && ExtTruncC == C)
230-
return TruncC;
231-
return nullptr;
232-
}
233-
234-
Constant *getLosslessUnsignedTrunc(Constant *C, Type *TruncTy) {
235-
return getLosslessTrunc(C, TruncTy, Instruction::ZExt);
236-
}
237-
238-
Constant *getLosslessSignedTrunc(Constant *C, Type *TruncTy) {
239-
return getLosslessTrunc(C, TruncTy, Instruction::SExt);
240-
}
241-
242225
std::optional<std::pair<Intrinsic::ID, SmallVector<Value *, 3>>>
243226
convertOrOfShiftsToFunnelShift(Instruction &Or);
244227

llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1642,10 +1642,11 @@ static Instruction *narrowUDivURem(BinaryOperator &I,
16421642
}
16431643

16441644
Constant *C;
1645+
auto &DL = IC.getDataLayout();
16451646
if (isa<Instruction>(N) && match(N, m_OneUse(m_ZExt(m_Value(X)))) &&
16461647
match(D, m_Constant(C))) {
16471648
// If the constant is the same in the smaller type, use the narrow version.
1648-
Constant *TruncC = IC.getLosslessUnsignedTrunc(C, X->getType());
1649+
Constant *TruncC = getLosslessUnsignedTrunc(C, X->getType(), DL);
16491650
if (!TruncC)
16501651
return nullptr;
16511652

@@ -1656,7 +1657,7 @@ static Instruction *narrowUDivURem(BinaryOperator &I,
16561657
if (isa<Instruction>(D) && match(D, m_OneUse(m_ZExt(m_Value(X)))) &&
16571658
match(N, m_Constant(C))) {
16581659
// If the constant is the same in the smaller type, use the narrow version.
1659-
Constant *TruncC = IC.getLosslessUnsignedTrunc(C, X->getType());
1660+
Constant *TruncC = getLosslessUnsignedTrunc(C, X->getType(), DL);
16601661
if (!TruncC)
16611662
return nullptr;
16621663

llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -841,7 +841,7 @@ Instruction *InstCombinerImpl::foldPHIArgZextsIntoPHI(PHINode &Phi) {
841841
NumZexts++;
842842
} else if (auto *C = dyn_cast<Constant>(V)) {
843843
// Make sure that constants can fit in the new type.
844-
Constant *Trunc = getLosslessUnsignedTrunc(C, NarrowType);
844+
Constant *Trunc = getLosslessUnsignedTrunc(C, NarrowType, DL);
845845
if (!Trunc)
846846
return nullptr;
847847
NewIncoming.push_back(Trunc);

llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2375,7 +2375,7 @@ Instruction *InstCombinerImpl::foldSelectExtConst(SelectInst &Sel) {
23752375
// If the constant is the same after truncation to the smaller type and
23762376
// extension to the original type, we can narrow the select.
23772377
Type *SelType = Sel.getType();
2378-
Constant *TruncC = getLosslessTrunc(C, SmallType, ExtOpcode);
2378+
Constant *TruncC = getLosslessInvCast(C, SmallType, ExtOpcode, DL);
23792379
if (TruncC && ExtInst->hasOneUse()) {
23802380
Value *TruncCVal = cast<Value>(TruncC);
23812381
if (ExtInst == Sel.getFalseValue())

llvm/lib/Transforms/InstCombine/InstructionCombining.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2568,7 +2568,7 @@ Instruction *InstCombinerImpl::narrowMathIfNoOverflow(BinaryOperator &BO) {
25682568
Constant *WideC;
25692569
if (!Op0->hasOneUse() || !match(Op1, m_Constant(WideC)))
25702570
return nullptr;
2571-
Constant *NarrowC = getLosslessTrunc(WideC, X->getType(), CastOpc);
2571+
Constant *NarrowC = getLosslessInvCast(WideC, X->getType(), CastOpc, DL);
25722572
if (!NarrowC)
25732573
return nullptr;
25742574
Y = NarrowC;

0 commit comments

Comments
 (0)