Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 143 additions & 0 deletions llvm/lib/Transforms/Vectorize/VectorCombine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ class VectorCombine {
bool foldInsExtBinop(Instruction &I);
bool foldInsExtVectorToShuffle(Instruction &I);
bool foldBitOpOfCastops(Instruction &I);
bool foldBitOpOfCastConstant(Instruction &I);
bool foldBitcastShuffle(Instruction &I);
bool scalarizeOpOrCmp(Instruction &I);
bool scalarizeVPIntrinsic(Instruction &I);
Expand Down Expand Up @@ -937,6 +938,146 @@ bool VectorCombine::foldBitOpOfCastops(Instruction &I) {
return true;
}

struct PreservedCastFlags {
bool NNeg = false;
bool NUW = false;
bool NSW = false;
};
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really need this? Especially if we're exposing getLosslessInvCastas a general helper

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe it's necessary. For users like InstCombine, they always try to preserve the flags as well as possible. With this, we can preserve flags more easily. For example, apply the helper in:

static Instruction *foldLogicCastConstant(BinaryOperator &Logic, CastInst *Cast,
InstCombinerImpl &IC) {
Constant *C = dyn_cast<Constant>(Logic.getOperand(1));
if (!C)
return nullptr;
auto LogicOpc = Logic.getOpcode();
Type *DestTy = Logic.getType();
Type *SrcTy = Cast->getSrcTy();
// Move the logic operation ahead of a zext or sext if the constant is
// unchanged in the smaller source type. Performing the logic in a smaller
// type may provide more information to later folds, and the smaller logic
// instruction may be cheaper (particularly in the case of vectors).
Value *X;
if (match(Cast, m_OneUse(m_ZExt(m_Value(X))))) {
if (Constant *TruncC = IC.getLosslessUnsignedTrunc(C, SrcTy)) {
// LogicOpc (zext X), C --> zext (LogicOpc X, C)
Value *NewOp = IC.Builder.CreateBinOp(LogicOpc, X, TruncC);
return new ZExtInst(NewOp, DestTy);
}
}
if (match(Cast, m_OneUse(m_SExtLike(m_Value(X))))) {
if (Constant *TruncC = IC.getLosslessSignedTrunc(C, SrcTy)) {
// LogicOpc (sext X), C --> sext (LogicOpc X, C)
Value *NewOp = IC.Builder.CreateBinOp(LogicOpc, X, TruncC);
return new SExtInst(NewOp, DestTy);
}
}
return nullptr;
}


// Try to cast C to InvC losslessly, satisfying CastOp(InvC) == C.
// Will try best to preserve the flags.
static Constant *getLosslessInvCast(Constant *C, Type *InvCastTo,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we could put this in InstCombineInternal.h where we have more methods like getLosslessUnsignedTrunc, getLosslessSignedTrunc, etc. ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it's possible to extract this as a public API, but I am not sure InstCombineInternal.h is a good place. Maybe InstructionSimplify.h is more proper.
And by this API, we can extend foldLogicCastConstant in InstCombine. Anyway, I would refactor and extend in a new patch.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nikic Any preference on where this helper function is moved to?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As this is based on ConstantFolding APIs, maybe ConstantFolding.h may be a good place?

Instruction::CastOps CastOp,
const DataLayout &DL,
PreservedCastFlags &Flags) {
switch (CastOp) {
case Instruction::BitCast:
// Bitcast is always lossless.
return ConstantFoldCastOperand(Instruction::BitCast, C, InvCastTo, DL);
case Instruction::Trunc: {
auto *ZExtC = ConstantFoldCastOperand(Instruction::ZExt, C, InvCastTo, DL);
auto *SExtC = ConstantFoldCastOperand(Instruction::SExt, C, InvCastTo, DL);
// Truncation back on ZExt value is always NUW.
Flags.NUW = true;
// Test positivity of C.
Flags.NSW = ZExtC == SExtC;
return ZExtC;
}
case Instruction::SExt:
case Instruction::ZExt: {
auto *InvC = ConstantExpr::getTrunc(C, InvCastTo);
auto *CastInvC = ConstantFoldCastOperand(CastOp, InvC, C->getType(), DL);
// Must satisfy CastOp(InvC) == C.
if (!CastInvC || CastInvC != C)
return nullptr;
if (CastOp == Instruction::ZExt) {
auto *SExtInvC =
ConstantFoldCastOperand(Instruction::SExt, InvC, C->getType(), DL);
// Test positivity of InvC.
Flags.NNeg = CastInvC == SExtInvC;
}
return InvC;
}
default:
return nullptr;
}
}

/// Match:
// bitop(castop(x), C) ->
// bitop(castop(x), castop(InvC)) ->
// castop(bitop(x, InvC))
// Supports: bitcast
bool VectorCombine::foldBitOpOfCastConstant(Instruction &I) {
Instruction *LHS;
Constant *C;

// Check if this is a bitwise logic operation
if (!match(&I, m_c_BitwiseLogic(m_Instruction(LHS), m_Constant(C))))
return false;

// Get the cast instructions
auto *LHSCast = dyn_cast<CastInst>(LHS);
if (!LHSCast)
return false;

Instruction::CastOps CastOpcode = LHSCast->getOpcode();

// Only handle supported cast operations
switch (CastOpcode) {
case Instruction::BitCast:
break;
default:
return false;
}

Value *LHSSrc = LHSCast->getOperand(0);

// Only handle vector types with integer elements
auto *SrcVecTy = dyn_cast<FixedVectorType>(LHSSrc->getType());
auto *DstVecTy = dyn_cast<FixedVectorType>(I.getType());
if (!SrcVecTy || !DstVecTy)
return false;

if (!SrcVecTy->getScalarType()->isIntegerTy() ||
!DstVecTy->getScalarType()->isIntegerTy())
return false;

// Find the constant InvC, such that castop(InvC) equals to C.
PreservedCastFlags RHSFlags;
Constant *InvC = getLosslessInvCast(C, SrcVecTy, CastOpcode, *DL, RHSFlags);
if (!InvC)
return false;

// Cost Check :
// OldCost = bitlogic + cast
// NewCost = bitlogic + cast

// Calculate specific costs for each cast with instruction context
InstructionCost LHSCastCost =
TTI.getCastInstrCost(CastOpcode, DstVecTy, SrcVecTy,
TTI::CastContextHint::None, CostKind, LHSCast);

InstructionCost OldCost =
TTI.getArithmeticInstrCost(I.getOpcode(), DstVecTy, CostKind) +
LHSCastCost;

// For new cost, we can't provide an instruction (it doesn't exist yet)
InstructionCost GenericCastCost = TTI.getCastInstrCost(
CastOpcode, DstVecTy, SrcVecTy, TTI::CastContextHint::None, CostKind);

InstructionCost NewCost =
TTI.getArithmeticInstrCost(I.getOpcode(), SrcVecTy, CostKind) +
GenericCastCost;

// Account for multi-use casts using specific costs
if (!LHSCast->hasOneUse())
NewCost += LHSCastCost;

LLVM_DEBUG(dbgs() << "foldBitOpOfCastConstant: OldCost=" << OldCost
<< " NewCost=" << NewCost << "\n");

if (NewCost > OldCost)
return false;

// Create the operation on the source type
Value *NewOp = Builder.CreateBinOp((Instruction::BinaryOps)I.getOpcode(),
LHSSrc, InvC, I.getName() + ".inner");
if (auto *NewBinOp = dyn_cast<BinaryOperator>(NewOp))
NewBinOp->copyIRFlags(&I);

Worklist.pushValue(NewOp);

// Create the cast operation directly to ensure we get a new instruction
Instruction *NewCast = CastInst::Create(CastOpcode, NewOp, I.getType());

// Insert the new instruction
Value *Result = Builder.Insert(NewCast);

replaceValue(I, *Result);
return true;
}

/// If this is a bitcast of a shuffle, try to bitcast the source vector to the
/// destination type followed by shuffle. This can enable further transforms by
/// moving bitcasts or shuffles together.
Expand Down Expand Up @@ -4474,6 +4615,8 @@ bool VectorCombine::run() {
case Instruction::Xor:
if (foldBitOpOfCastops(I))
return true;
if (foldBitOpOfCastConstant(I))
return true;
break;
case Instruction::PHI:
if (shrinkPhiOfShuffles(I))
Expand Down
160 changes: 160 additions & 0 deletions llvm/test/Transforms/VectorCombine/X86/bitop-of-castops.ll
Original file line number Diff line number Diff line change
Expand Up @@ -260,3 +260,163 @@ define <4 x i32> @or_zext_nneg(<4 x i16> %a, <4 x i16> %b) {
%or = or <4 x i32> %z1, %z2
ret <4 x i32> %or
}

; Test bitwise operations with integer-to-integer bitcast with one constant
define <2 x i32> @or_bitcast_v4i16_to_v2i32_constant(<4 x i16> %a) {
; CHECK-LABEL: @or_bitcast_v4i16_to_v2i32_constant(
; CHECK-NEXT: [[A:%.*]] = or <4 x i16> [[A1:%.*]], <i16 16960, i16 15, i16 -31616, i16 30>
; CHECK-NEXT: [[BC1:%.*]] = bitcast <4 x i16> [[A]] to <2 x i32>
; CHECK-NEXT: ret <2 x i32> [[BC1]]
;
%bc1 = bitcast <4 x i16> %a to <2 x i32>
%or = or <2 x i32> %bc1, <i32 1000000, i32 2000000>
ret <2 x i32> %or
}

define <2 x i32> @or_bitcast_v4i16_to_v2i32_constant_commuted(<4 x i16> %a) {
; CHECK-LABEL: @or_bitcast_v4i16_to_v2i32_constant_commuted(
; CHECK-NEXT: [[A:%.*]] = or <4 x i16> [[A1:%.*]], <i16 16960, i16 15, i16 -31616, i16 30>
; CHECK-NEXT: [[BC1:%.*]] = bitcast <4 x i16> [[A]] to <2 x i32>
; CHECK-NEXT: ret <2 x i32> [[BC1]]
;
%bc1 = bitcast <4 x i16> %a to <2 x i32>
%or = or <2 x i32> <i32 1000000, i32 2000000>, %bc1
ret <2 x i32> %or
}

; Test bitwise operations with truncate and one constant
define <4 x i16> @or_trunc_v4i32_to_v4i16_constant(<4 x i32> %a) {
; CHECK-LABEL: @or_trunc_v4i32_to_v4i16_constant(
; CHECK-NEXT: [[T1:%.*]] = trunc <4 x i32> [[A:%.*]] to <4 x i16>
; CHECK-NEXT: [[OR:%.*]] = or <4 x i16> [[T1]], <i16 1, i16 2, i16 3, i16 4>
; CHECK-NEXT: ret <4 x i16> [[OR]]
;
%t1 = trunc <4 x i32> %a to <4 x i16>
%or = or <4 x i16> %t1, <i16 1, i16 2, i16 3, i16 4>
ret <4 x i16> %or
}

; Test bitwise operations with zero extend and one constant
define <4 x i32> @or_zext_v4i16_to_v4i32_constant(<4 x i16> %a) {
; CHECK-LABEL: @or_zext_v4i16_to_v4i32_constant(
; CHECK-NEXT: [[Z1:%.*]] = zext <4 x i16> [[A:%.*]] to <4 x i32>
; CHECK-NEXT: [[OR:%.*]] = or <4 x i32> [[Z1]], <i32 1, i32 2, i32 3, i32 4>
; CHECK-NEXT: ret <4 x i32> [[OR]]
;
%z1 = zext <4 x i16> %a to <4 x i32>
%or = or <4 x i32> %z1, <i32 1, i32 2, i32 3, i32 4>
ret <4 x i32> %or
}

define <4 x i32> @or_zext_v4i8_to_v4i32_constant_with_loss(<4 x i8> %a) {
; CHECK-LABEL: @or_zext_v4i8_to_v4i32_constant_with_loss(
; CHECK-NEXT: [[Z1:%.*]] = zext <4 x i8> [[A:%.*]] to <4 x i32>
; CHECK-NEXT: [[OR:%.*]] = or <4 x i32> [[Z1]], <i32 1024, i32 129, i32 3, i32 4>
; CHECK-NEXT: ret <4 x i32> [[OR]]
;
%z1 = zext <4 x i8> %a to <4 x i32>
%or = or <4 x i32> %z1, <i32 1024, i32 129, i32 3, i32 4>
ret <4 x i32> %or
}

; Test bitwise operations with sign extend and one constant
define <4 x i32> @or_sext_v4i8_to_v4i32_positive_constant(<4 x i8> %a) {
; CHECK-LABEL: @or_sext_v4i8_to_v4i32_positive_constant(
; CHECK-NEXT: [[S1:%.*]] = sext <4 x i8> [[A:%.*]] to <4 x i32>
; CHECK-NEXT: [[OR:%.*]] = or <4 x i32> [[S1]], <i32 1, i32 2, i32 3, i32 4>
; CHECK-NEXT: ret <4 x i32> [[OR]]
;
%s1 = sext <4 x i8> %a to <4 x i32>
%or = or <4 x i32> %s1, <i32 1, i32 2, i32 3, i32 4>
ret <4 x i32> %or
}

define <4 x i32> @or_sext_v4i8_to_v4i32_minus_constant(<4 x i8> %a) {
; CHECK-LABEL: @or_sext_v4i8_to_v4i32_minus_constant(
; CHECK-NEXT: [[S1:%.*]] = sext <4 x i8> [[A:%.*]] to <4 x i32>
; CHECK-NEXT: [[OR:%.*]] = or <4 x i32> [[S1]], <i32 -1, i32 -2, i32 -3, i32 -4>
; CHECK-NEXT: ret <4 x i32> [[OR]]
;
%s1 = sext <4 x i8> %a to <4 x i32>
%or = or <4 x i32> %s1, <i32 -1, i32 -2, i32 -3, i32 -4>
ret <4 x i32> %or
}

define <4 x i32> @or_sext_v4i8_to_v4i32_constant_with_loss(<4 x i8> %a) {
; CHECK-LABEL: @or_sext_v4i8_to_v4i32_constant_with_loss(
; CHECK-NEXT: [[Z1:%.*]] = sext <4 x i8> [[A:%.*]] to <4 x i32>
; CHECK-NEXT: [[OR:%.*]] = or <4 x i32> [[Z1]], <i32 -10000, i32 2, i32 3, i32 4>
; CHECK-NEXT: ret <4 x i32> [[OR]]
;
%z1 = sext <4 x i8> %a to <4 x i32>
%or = or <4 x i32> %z1, <i32 -10000, i32 2, i32 3, i32 4>
ret <4 x i32> %or
}

; Test truncate with flag preservation and one constant
define <4 x i16> @and_trunc_nuw_nsw_constant(<4 x i32> %a) {
; CHECK-LABEL: @and_trunc_nuw_nsw_constant(
; CHECK-NEXT: [[T1:%.*]] = trunc nuw nsw <4 x i32> [[A:%.*]] to <4 x i16>
; CHECK-NEXT: [[AND:%.*]] = and <4 x i16> [[T1]], <i16 1, i16 2, i16 3, i16 4>
; CHECK-NEXT: ret <4 x i16> [[AND]]
;
%t1 = trunc nuw nsw <4 x i32> %a to <4 x i16>
%and = and <4 x i16> %t1, <i16 1, i16 2, i16 3, i16 4>
ret <4 x i16> %and
}

define <4 x i8> @and_trunc_nuw_nsw_minus_constant(<4 x i32> %a) {
; CHECK-LABEL: @and_trunc_nuw_nsw_minus_constant(
; CHECK-NEXT: [[T1:%.*]] = trunc nuw nsw <4 x i32> [[A:%.*]] to <4 x i8>
; CHECK-NEXT: [[AND:%.*]] = and <4 x i8> [[T1]], <i8 -16, i8 -15, i8 -14, i8 -13>
; CHECK-NEXT: ret <4 x i8> [[AND]]
;
%t1 = trunc nuw nsw <4 x i32> %a to <4 x i8>
%and = and <4 x i8> %t1, <i8 240, i8 241, i8 242, i8 243>
ret <4 x i8> %and
}

define <4 x i8> @and_trunc_nuw_nsw_multiconstant(<4 x i32> %a) {
; CHECK-LABEL: @and_trunc_nuw_nsw_multiconstant(
; CHECK-NEXT: [[T1:%.*]] = trunc nuw nsw <4 x i32> [[A:%.*]] to <4 x i8>
; CHECK-NEXT: [[AND:%.*]] = and <4 x i8> [[T1]], <i8 -16, i8 1, i8 -14, i8 3>
; CHECK-NEXT: ret <4 x i8> [[AND]]
;
%t1 = trunc nuw nsw <4 x i32> %a to <4 x i8>
%and = and <4 x i8> %t1, <i8 240, i8 1, i8 242, i8 3>
ret <4 x i8> %and
}

; Test sign extend with nneg flag and one constant
define <4 x i32> @or_zext_nneg_constant(<4 x i16> %a) {
; CHECK-LABEL: @or_zext_nneg_constant(
; CHECK-NEXT: [[Z1:%.*]] = zext nneg <4 x i16> [[A:%.*]] to <4 x i32>
; CHECK-NEXT: [[OR:%.*]] = or <4 x i32> [[Z1]], <i32 1, i32 2, i32 3, i32 4>
; CHECK-NEXT: ret <4 x i32> [[OR]]
;
%z1 = zext nneg <4 x i16> %a to <4 x i32>
%or = or <4 x i32> %z1, <i32 1, i32 2, i32 3, i32 4>
ret <4 x i32> %or
}

define <4 x i32> @or_zext_nneg_minus_constant(<4 x i8> %a) {
; CHECK-LABEL: @or_zext_nneg_minus_constant(
; CHECK-NEXT: [[Z1:%.*]] = zext nneg <4 x i8> [[A:%.*]] to <4 x i32>
; CHECK-NEXT: [[OR:%.*]] = or <4 x i32> [[Z1]], <i32 240, i32 241, i32 242, i32 243>
; CHECK-NEXT: ret <4 x i32> [[OR]]
;
%z1 = zext nneg <4 x i8> %a to <4 x i32>
%or = or <4 x i32> %z1, <i32 240, i32 241, i32 242, i32 243>
ret <4 x i32> %or
}

define <4 x i32> @or_zext_nneg_multiconstant(<4 x i8> %a) {
; CHECK-LABEL: @or_zext_nneg_multiconstant(
; CHECK-NEXT: [[Z1:%.*]] = zext nneg <4 x i8> [[A:%.*]] to <4 x i32>
; CHECK-NEXT: [[OR:%.*]] = or <4 x i32> [[Z1]], <i32 240, i32 1, i32 242, i32 3>
; CHECK-NEXT: ret <4 x i32> [[OR]]
;
%z1 = zext nneg <4 x i8> %a to <4 x i32>
%or = or <4 x i32> %z1, <i32 240, i32 1, i32 242, i32 3>
ret <4 x i32> %or
}