[VectorCombine] Fold "(or (zext (bitcast X)), (shl (zext (bitcast Y)), C))" -> "(bitcast (concat X, Y))" MOVMSK bool mask style patterns (llvm#119695)

Mask/bool vectors are often bitcast to/from scalar integers, in particular when concatenating mask results, often because of the difficulty of working with vectors of bools in C/C++. On x86 this typically involves the MOVMSK/KMOV instructions.

To concatenate bool masks, they are typically bitcast to scalars, which are then zero-extended, shifted, and OR'd together.
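
For context, a minimal C++ sketch of such a concatenation using SSE intrinsics (the function and its inputs are hypothetical; _mm_movemask_ps extracts the 4-bit sign mask of its operand):

  #include <immintrin.h>
  #include <cstdint>

  // Hypothetical example: concatenate two 4-bit MOVMSK masks into one 8-bit
  // mask. Each mask is zero-extended to a wider scalar, the second is shifted
  // past the width of the first, and the two are OR'd together.
  uint8_t concat_masks(__m128 m0, __m128 m1) {
    uint32_t lo = (uint32_t)_mm_movemask_ps(m0);
    uint32_t hi = (uint32_t)_mm_movemask_ps(m1);
    return (uint8_t)(lo | (hi << 4));
  }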

This patch attempts to match these scalar concatenation patterns and convert them to vector shuffles instead. This in turn often assists with further vector combines, depending on the cost model.
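
Concretely, a hand-written LLVM IR sketch of the fold for a pair of <4 x i1> masks (value names are hypothetical):

  ; Before: scalar concatenation of two <4 x i1> masks into an i8.
  %bx = bitcast <4 x i1> %x to i4
  %by = bitcast <4 x i1> %y to i4
  %zx = zext i4 %bx to i8
  %zy = zext i4 %by to i8
  %sy = shl i8 %zy, 4
  %r  = or disjoint i8 %zx, %sy

  ; After: concatenate the masks with a shuffle, then bitcast once.
  %cc = shufflevector <4 x i1> %x, <4 x i1> %y,
                      <8 x i32> <i32 0, i32 1, i32 2, i32 3,
                                 i32 4, i32 5, i32 6, i32 7>
  %r  = bitcast <8 x i1> %cc to i8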

Reapplied patch from llvm#119559, with the use-after-free issue fixed.

Fixes llvm#111431
RKSimon authored Dec 12, 2024
1 parent f9734b9 commit 86779da
Showing 2 changed files with 266 additions and 114 deletions.
111 changes: 111 additions & 0 deletions llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -115,6 +115,7 @@ class VectorCombine {
bool foldExtractedCmps(Instruction &I);
bool foldSingleElementStore(Instruction &I);
bool scalarizeLoadExtract(Instruction &I);
bool foldConcatOfBoolMasks(Instruction &I);
bool foldPermuteOfBinops(Instruction &I);
bool foldShuffleOfBinops(Instruction &I);
bool foldShuffleOfCastops(Instruction &I);
@@ -1423,6 +1424,113 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
return true;
}

/// Try to fold "(or (zext (bitcast X)), (shl (zext (bitcast Y)), C))"
/// to "(bitcast (concat X, Y))"
/// where X/Y are bitcasted from i1 mask vectors.
bool VectorCombine::foldConcatOfBoolMasks(Instruction &I) {
Type *Ty = I.getType();
if (!Ty->isIntegerTy())
return false;

// TODO: Add big endian test coverage
if (DL->isBigEndian())
return false;

// Restrict to disjoint cases so the mask vectors aren't overlapping.
Instruction *X, *Y;
if (!match(&I, m_DisjointOr(m_Instruction(X), m_Instruction(Y))))
return false;

// Allow both sources to contain shl, to handle more generic pattern:
// "(or (shl (zext (bitcast X)), C1), (shl (zext (bitcast Y)), C2))"
Value *SrcX;
uint64_t ShAmtX = 0;
if (!match(X, m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcX)))))) &&
!match(X, m_OneUse(
m_Shl(m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcX))))),
m_ConstantInt(ShAmtX)))))
return false;

Value *SrcY;
uint64_t ShAmtY = 0;
if (!match(Y, m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcY)))))) &&
!match(Y, m_OneUse(
m_Shl(m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcY))))),
m_ConstantInt(ShAmtY)))))
return false;

// Canonicalize larger shift to the RHS.
if (ShAmtX > ShAmtY) {
std::swap(X, Y);
std::swap(SrcX, SrcY);
std::swap(ShAmtX, ShAmtY);
}

// Ensure both sources are matching vXi1 bool mask types, and that the shift
// difference is the mask width so they can be easily concatenated together.
uint64_t ShAmtDiff = ShAmtY - ShAmtX;
unsigned NumSHL = (ShAmtX > 0) + (ShAmtY > 0);
unsigned BitWidth = Ty->getPrimitiveSizeInBits();
auto *MaskTy = dyn_cast<FixedVectorType>(SrcX->getType());
if (!MaskTy || SrcX->getType() != SrcY->getType() ||
!MaskTy->getElementType()->isIntegerTy(1) ||
MaskTy->getNumElements() != ShAmtDiff ||
MaskTy->getNumElements() > (BitWidth / 2))
return false;
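// At this point the fold is known to be legal. Illustrative example (not
// part of the patch): concatenating two <4 x i1> masks into an i8 gives
// ShAmtX = 0 and ShAmtY = 4, so ShAmtDiff = 4 matches the 4-element mask
// width, and 4 <= BitWidth / 2 = 4.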

auto *ConcatTy = FixedVectorType::getDoubleElementsVectorType(MaskTy);
auto *ConcatIntTy =
Type::getIntNTy(Ty->getContext(), ConcatTy->getNumElements());
auto *MaskIntTy = Type::getIntNTy(Ty->getContext(), ShAmtDiff);

SmallVector<int, 32> ConcatMask(ConcatTy->getNumElements());
std::iota(ConcatMask.begin(), ConcatMask.end(), 0);
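// E.g. for two <4 x i1> sources this is the identity concatenation mask
// <0,1,2,3,4,5,6,7>, selecting every lane of SrcX followed by every lane
// of SrcY.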

// TODO: Is it worth supporting multi use cases?
InstructionCost OldCost = 0;
OldCost += TTI.getArithmeticInstrCost(Instruction::Or, Ty, CostKind);
OldCost +=
NumSHL * TTI.getArithmeticInstrCost(Instruction::Shl, Ty, CostKind);
OldCost += 2 * TTI.getCastInstrCost(Instruction::ZExt, Ty, MaskIntTy,
TTI::CastContextHint::None, CostKind);
OldCost += 2 * TTI.getCastInstrCost(Instruction::BitCast, MaskIntTy, MaskTy,
TTI::CastContextHint::None, CostKind);

InstructionCost NewCost = 0;
NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, MaskTy,
ConcatMask, CostKind);
NewCost += TTI.getCastInstrCost(Instruction::BitCast, ConcatIntTy, ConcatTy,
TTI::CastContextHint::None, CostKind);
if (Ty != ConcatIntTy)
NewCost += TTI.getCastInstrCost(Instruction::ZExt, Ty, ConcatIntTy,
TTI::CastContextHint::None, CostKind);
if (ShAmtX > 0)
NewCost += TTI.getArithmeticInstrCost(Instruction::Shl, Ty, CostKind);

if (NewCost > OldCost)
return false;

// Build bool mask concatenation, bitcast back to scalar integer, and perform
// any residual zero-extension or shifting.
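// Illustrative example: with ShAmtX = 4 and ShAmtY = 8 on an i16 result,
// "(or (shl (zext (bitcast X)), 4), (shl (zext (bitcast Y)), 8))" folds to
// "(shl (zext (bitcast (concat X, Y))), 4)".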
Value *Concat = Builder.CreateShuffleVector(SrcX, SrcY, ConcatMask);
Worklist.pushValue(Concat);

Value *Result = Builder.CreateBitCast(Concat, ConcatIntTy);

if (Ty != ConcatIntTy) {
Worklist.pushValue(Result);
Result = Builder.CreateZExt(Result, Ty);
}

if (ShAmtX > 0) {
Worklist.pushValue(Result);
Result = Builder.CreateShl(Result, ShAmtX);
}

replaceValue(I, *Result);
return true;
}

/// Try to convert "shuffle (binop (shuffle, shuffle)), undef"
/// --> "binop (shuffle), (shuffle)".
bool VectorCombine::foldPermuteOfBinops(Instruction &I) {
@@ -2945,6 +3053,9 @@ bool VectorCombine::run() {
case Instruction::FCmp:
MadeChange |= foldExtractExtract(I);
break;
case Instruction::Or:
MadeChange |= foldConcatOfBoolMasks(I);
[[fallthrough]];
default:
if (Instruction::isBinaryOp(Opcode)) {
MadeChange |= foldExtractExtract(I);