-
Notifications
You must be signed in to change notification settings - Fork 15k
[VectorCombine] Support pattern bitop(bitcast(x), C) -> bitcast(bitop(x, InvC))
#155216
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -122,6 +122,7 @@ class VectorCombine { | |
| bool foldInsExtBinop(Instruction &I); | ||
| bool foldInsExtVectorToShuffle(Instruction &I); | ||
| bool foldBitOpOfCastops(Instruction &I); | ||
| bool foldBitOpOfCastConstant(Instruction &I); | ||
| bool foldBitcastShuffle(Instruction &I); | ||
| bool scalarizeOpOrCmp(Instruction &I); | ||
| bool scalarizeVPIntrinsic(Instruction &I); | ||
|
|
@@ -937,6 +938,146 @@ bool VectorCombine::foldBitOpOfCastops(Instruction &I) { | |
| return true; | ||
| } | ||
|
|
||
| struct PreservedCastFlags { | ||
| bool NNeg = false; | ||
| bool NUW = false; | ||
| bool NSW = false; | ||
| }; | ||
|
|
||
| // Try to cast C to InvC losslessly, satisfying CastOp(InvC) == C. | ||
| // Will try best to preserve the flags. | ||
| static Constant *getLosslessInvCast(Constant *C, Type *InvCastTo, | ||
|
||
| Instruction::CastOps CastOp, | ||
| const DataLayout &DL, | ||
| PreservedCastFlags &Flags) { | ||
| switch (CastOp) { | ||
| case Instruction::BitCast: | ||
| // Bitcast is always lossless. | ||
| return ConstantFoldCastOperand(Instruction::BitCast, C, InvCastTo, DL); | ||
| case Instruction::Trunc: { | ||
| auto *ZExtC = ConstantFoldCastOperand(Instruction::ZExt, C, InvCastTo, DL); | ||
| auto *SExtC = ConstantFoldCastOperand(Instruction::SExt, C, InvCastTo, DL); | ||
| // Truncation back on ZExt value is always NUW. | ||
| Flags.NUW = true; | ||
| // Test positivity of C. | ||
| Flags.NSW = ZExtC == SExtC; | ||
| return ZExtC; | ||
| } | ||
| case Instruction::SExt: | ||
| case Instruction::ZExt: { | ||
| auto *InvC = ConstantExpr::getTrunc(C, InvCastTo); | ||
| auto *CastInvC = ConstantFoldCastOperand(CastOp, InvC, C->getType(), DL); | ||
| // Must satisfy CastOp(InvC) == C. | ||
| if (!CastInvC || CastInvC != C) | ||
| return nullptr; | ||
| if (CastOp == Instruction::ZExt) { | ||
| auto *SExtInvC = | ||
| ConstantFoldCastOperand(Instruction::SExt, InvC, C->getType(), DL); | ||
| // Test positivity of InvC. | ||
| Flags.NNeg = CastInvC == SExtInvC; | ||
| } | ||
| return InvC; | ||
| } | ||
| default: | ||
| return nullptr; | ||
| } | ||
| } | ||
|
|
||
| /// Match: | ||
| // bitop(castop(x), C) -> | ||
| // bitop(castop(x), castop(InvC)) -> | ||
| // castop(bitop(x, InvC)) | ||
| // Supports: bitcast | ||
| bool VectorCombine::foldBitOpOfCastConstant(Instruction &I) { | ||
| Instruction *LHS; | ||
| Constant *C; | ||
|
|
||
| // Check if this is a bitwise logic operation | ||
| if (!match(&I, m_c_BitwiseLogic(m_Instruction(LHS), m_Constant(C)))) | ||
| return false; | ||
|
|
||
| // Get the cast instructions | ||
| auto *LHSCast = dyn_cast<CastInst>(LHS); | ||
| if (!LHSCast) | ||
| return false; | ||
|
|
||
| Instruction::CastOps CastOpcode = LHSCast->getOpcode(); | ||
|
|
||
| // Only handle supported cast operations | ||
| switch (CastOpcode) { | ||
| case Instruction::BitCast: | ||
| break; | ||
| default: | ||
| return false; | ||
| } | ||
|
|
||
| Value *LHSSrc = LHSCast->getOperand(0); | ||
|
|
||
| // Only handle vector types with integer elements | ||
| auto *SrcVecTy = dyn_cast<FixedVectorType>(LHSSrc->getType()); | ||
| auto *DstVecTy = dyn_cast<FixedVectorType>(I.getType()); | ||
| if (!SrcVecTy || !DstVecTy) | ||
| return false; | ||
|
|
||
| if (!SrcVecTy->getScalarType()->isIntegerTy() || | ||
| !DstVecTy->getScalarType()->isIntegerTy()) | ||
| return false; | ||
|
|
||
| // Find the constant InvC, such that castop(InvC) equals to C. | ||
| PreservedCastFlags RHSFlags; | ||
| Constant *InvC = getLosslessInvCast(C, SrcVecTy, CastOpcode, *DL, RHSFlags); | ||
| if (!InvC) | ||
| return false; | ||
|
|
||
| // Cost Check : | ||
| // OldCost = bitlogic + cast | ||
| // NewCost = bitlogic + cast | ||
|
|
||
| // Calculate specific costs for each cast with instruction context | ||
| InstructionCost LHSCastCost = | ||
| TTI.getCastInstrCost(CastOpcode, DstVecTy, SrcVecTy, | ||
| TTI::CastContextHint::None, CostKind, LHSCast); | ||
|
|
||
| InstructionCost OldCost = | ||
| TTI.getArithmeticInstrCost(I.getOpcode(), DstVecTy, CostKind) + | ||
| LHSCastCost; | ||
|
|
||
| // For new cost, we can't provide an instruction (it doesn't exist yet) | ||
| InstructionCost GenericCastCost = TTI.getCastInstrCost( | ||
| CastOpcode, DstVecTy, SrcVecTy, TTI::CastContextHint::None, CostKind); | ||
|
|
||
| InstructionCost NewCost = | ||
| TTI.getArithmeticInstrCost(I.getOpcode(), SrcVecTy, CostKind) + | ||
| GenericCastCost; | ||
|
|
||
| // Account for multi-use casts using specific costs | ||
| if (!LHSCast->hasOneUse()) | ||
| NewCost += LHSCastCost; | ||
|
|
||
| LLVM_DEBUG(dbgs() << "foldBitOpOfCastConstant: OldCost=" << OldCost | ||
| << " NewCost=" << NewCost << "\n"); | ||
|
|
||
| if (NewCost > OldCost) | ||
| return false; | ||
|
|
||
| // Create the operation on the source type | ||
| Value *NewOp = Builder.CreateBinOp((Instruction::BinaryOps)I.getOpcode(), | ||
| LHSSrc, InvC, I.getName() + ".inner"); | ||
| if (auto *NewBinOp = dyn_cast<BinaryOperator>(NewOp)) | ||
| NewBinOp->copyIRFlags(&I); | ||
|
|
||
| Worklist.pushValue(NewOp); | ||
|
|
||
| // Create the cast operation directly to ensure we get a new instruction | ||
| Instruction *NewCast = CastInst::Create(CastOpcode, NewOp, I.getType()); | ||
|
|
||
| // Insert the new instruction | ||
| Value *Result = Builder.Insert(NewCast); | ||
|
|
||
| replaceValue(I, *Result); | ||
| return true; | ||
| } | ||
|
|
||
| /// If this is a bitcast of a shuffle, try to bitcast the source vector to the | ||
| /// destination type followed by shuffle. This can enable further transforms by | ||
| /// moving bitcasts or shuffles together. | ||
|
|
@@ -4474,6 +4615,8 @@ bool VectorCombine::run() { | |
| case Instruction::Xor: | ||
| if (foldBitOpOfCastops(I)) | ||
| return true; | ||
| if (foldBitOpOfCastConstant(I)) | ||
| return true; | ||
| break; | ||
| case Instruction::PHI: | ||
| if (shrinkPhiOfShuffles(I)) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we really need this? Especially if we're exposing getLosslessInvCastas a general helper
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe it's necessary. For users like InstCombine, they always try to preserve the flags as well as possible. With this, we can preserve flags more easily. For example, apply the helper in:
llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
Lines 1787 to 1819 in 16494be