Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[VectorCombine] Add foldShuffleToIdentity #88693

Merged
merged 6 commits into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 147 additions & 0 deletions llvm/lib/Transforms/Vectorize/VectorCombine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ class VectorCombine {
bool foldShuffleOfBinops(Instruction &I);
bool foldShuffleOfCastops(Instruction &I);
bool foldShuffleOfShuffles(Instruction &I);
bool foldShuffleToIdentity(Instruction &I);
bool foldShuffleFromReductions(Instruction &I);
bool foldTruncFromReductions(Instruction &I);
bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
Expand Down Expand Up @@ -1667,6 +1668,151 @@ bool VectorCombine::foldShuffleOfShuffles(Instruction &I) {
return true;
}

// Starting from a shuffle, look up through operands tracking the shuffled index
// of each lane. If we can simplify away the shuffles to identities then
// do so.
bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
auto *Ty = dyn_cast<FixedVectorType>(I.getType());
if (!Ty || !isa<Instruction>(I.getOperand(0)) ||
!isa<Instruction>(I.getOperand(1)))
return false;

using InstLane = std::pair<Value *, int>;

auto LookThroughShuffles = [](Value *V, int Lane) -> InstLane {
while (auto *SV = dyn_cast<ShuffleVectorInst>(V)) {
unsigned NumElts =
cast<FixedVectorType>(SV->getOperand(0)->getType())->getNumElements();
int M = SV->getMaskValue(Lane);
if (M < 0)
return {nullptr, PoisonMaskElem};
else if (M < (int)NumElts) {
V = SV->getOperand(0);
Lane = M;
} else {
V = SV->getOperand(1);
Lane = M - NumElts;
}
}
return InstLane{V, Lane};
};

auto GenerateInstLaneVectorFromOperand =
[&LookThroughShuffles](ArrayRef<InstLane> Item, int Op) {
SmallVector<InstLane> NItem;
for (InstLane V : Item) {
NItem.emplace_back(
!V.first
? InstLane{nullptr, PoisonMaskElem}
: LookThroughShuffles(
cast<Instruction>(V.first)->getOperand(Op), V.second));
}
return NItem;
};

SmallVector<InstLane> Start(Ty->getNumElements());
for (unsigned M = 0, E = Ty->getNumElements(); M < E; ++M)
Start[M] = LookThroughShuffles(&I, M);

SmallVector<SmallVector<InstLane>> Worklist;
Worklist.push_back(Start);
SmallPtrSet<Value *, 4> IdentityLeafs, SplatLeafs;
unsigned NumVisited = 0;

while (!Worklist.empty()) {
SmallVector<InstLane> Item = Worklist.pop_back_val();
if (++NumVisited > MaxInstrsToScan)
return false;

// If we found an undef first lane then bail out to keep things simple.
if (!Item[0].first)
return false;

// Look for an identity value.
if (Item[0].second == 0 && Item[0].first->getType() == Ty &&
all_of(drop_begin(enumerate(Item)), [&](const auto &E) {
return !E.value().first || (E.value().first == Item[0].first &&
E.value().second == (int)E.index());
})) {
IdentityLeafs.insert(Item[0].first);
continue;
}
// Look for a splat value.
if (all_of(drop_begin(Item), [&](InstLane &IL) {
return !IL.first ||
(IL.first == Item[0].first && IL.second == Item[0].second);
})) {
RKSimon marked this conversation as resolved.
Show resolved Hide resolved
SplatLeafs.insert(Item[0].first);
continue;
}

// We need each element to be the same type of value, and check that each
// element has a single use.
if (!all_of(drop_begin(Item), [&](InstLane IL) {
if (!IL.first)
return true;
if (auto *I = dyn_cast<Instruction>(IL.first); I && !I->hasOneUse())
return false;
if (IL.first->getValueID() != Item[0].first->getValueID())
return false;
auto *II = dyn_cast<IntrinsicInst>(IL.first);
return !II ||
II->getIntrinsicID() ==
cast<IntrinsicInst>(Item[0].first)->getIntrinsicID();
}))
return false;

// Check the operator is one that we support. We exclude div/rem in case
// they hit UB from poison lanes.
if (isa<BinaryOperator>(Item[0].first) &&
!cast<BinaryOperator>(Item[0].first)->isIntDivRem()) {
Worklist.push_back(GenerateInstLaneVectorFromOperand(Item, 0));
Worklist.push_back(GenerateInstLaneVectorFromOperand(Item, 1));
} else if (isa<UnaryOperator>(Item[0].first)) {
Worklist.push_back(GenerateInstLaneVectorFromOperand(Item, 0));
} else {
return false;
}
}

// If we got this far, we know the shuffles are superfluous and can be
// removed. Scan through again and generate the new tree of instructions.
std::function<Value *(ArrayRef<InstLane>)> Generate =
[&](ArrayRef<InstLane> Item) -> Value * {
if (IdentityLeafs.contains(Item[0].first) &&
all_of(drop_begin(enumerate(Item)), [&](const auto &E) {
return !E.value().first || (E.value().first == Item[0].first &&
E.value().second == (int)E.index());
})) {
return Item[0].first;
}
if (SplatLeafs.contains(Item[0].first)) {
if (auto ILI = dyn_cast<Instruction>(Item[0].first))
Builder.SetInsertPoint(*ILI->getInsertionPointAfterDef());
else if (isa<Argument>(Item[0].first))
Builder.SetInsertPointPastAllocas(I.getParent()->getParent());
SmallVector<int, 16> Mask(Ty->getNumElements(), Item[0].second);
return Builder.CreateShuffleVector(Item[0].first, Mask);
}

auto *I = cast<Instruction>(Item[0].first);
SmallVector<Value *> Ops(I->getNumOperands());
for (unsigned Idx = 0, E = I->getNumOperands(); Idx < E; Idx++)
Ops[Idx] = Generate(GenerateInstLaneVectorFromOperand(Item, Idx));
Builder.SetInsertPoint(I);
if (auto BI = dyn_cast<BinaryOperator>(I))
return Builder.CreateBinOp((Instruction::BinaryOps)BI->getOpcode(),
Ops[0], Ops[1]);
assert(isa<UnaryInstruction>(I) &&
"Unexpected instruction type in Generate");
return Builder.CreateUnOp((Instruction::UnaryOps)I->getOpcode(), Ops[0]);
};

Value *V = Generate(Start);
replaceValue(I, *V);
return true;
}

/// Given a commutative reduction, the order of the input lanes does not alter
/// the results. We can use this to remove certain shuffles feeding the
/// reduction, removing the need to shuffle at all.
Expand Down Expand Up @@ -2224,6 +2370,7 @@ bool VectorCombine::run() {
MadeChange |= foldShuffleOfCastops(I);
MadeChange |= foldShuffleOfShuffles(I);
MadeChange |= foldSelectShuffle(I);
MadeChange |= foldShuffleToIdentity(I);
break;
case Instruction::BitCast:
MadeChange |= foldBitcastShuffle(I);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,13 @@ define void @add4(ptr noalias noundef %x, ptr noalias noundef %y, i32 noundef %n
; CHECK-NEXT: [[WIDE_VEC:%.*]] = load <32 x i16>, ptr [[TMP0]], align 2
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i16, ptr [[X]], i64 [[OFFSET_IDX]]
; CHECK-NEXT: [[WIDE_VEC24:%.*]] = load <32 x i16>, ptr [[TMP1]], align 2
; CHECK-NEXT: [[TMP2:%.*]] = add <32 x i16> [[WIDE_VEC24]], [[WIDE_VEC]]
; CHECK-NEXT: [[TMP3:%.*]] = add <32 x i16> [[WIDE_VEC24]], [[WIDE_VEC]]
; CHECK-NEXT: [[TMP4:%.*]] = add <32 x i16> [[WIDE_VEC24]], [[WIDE_VEC]]
; CHECK-NEXT: [[TMP5:%.*]] = or disjoint i64 [[OFFSET_IDX]], 3
; CHECK-NEXT: [[TMP6:%.*]] = add <32 x i16> [[WIDE_VEC24]], [[WIDE_VEC]]
; CHECK-NEXT: [[GEP:%.*]] = getelementptr i16, ptr [[INVARIANT_GEP]], i64 [[TMP5]]
; CHECK-NEXT: [[TMP7:%.*]] = shufflevector <32 x i16> [[TMP2]], <32 x i16> [[TMP3]], <16 x i32> <i32 0, i32 4, i32 8, i32 12, i32 16, i32 20, i32 24, i32 28, i32 33, i32 37, i32 41, i32 45, i32 49, i32 53, i32 57, i32 61>
; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <32 x i16> [[TMP4]], <32 x i16> [[TMP6]], <16 x i32> <i32 2, i32 6, i32 10, i32 14, i32 18, i32 22, i32 26, i32 30, i32 35, i32 39, i32 43, i32 47, i32 51, i32 55, i32 59, i32 63>
; CHECK-NEXT: [[INTERLEAVED_VEC:%.*]] = shufflevector <16 x i16> [[TMP7]], <16 x i16> [[TMP8]], <32 x i32> <i32 0, i32 8, i32 16, i32 24, i32 1, i32 9, i32 17, i32 25, i32 2, i32 10, i32 18, i32 26, i32 3, i32 11, i32 19, i32 27, i32 4, i32 12, i32 20, i32 28, i32 5, i32 13, i32 21, i32 29, i32 6, i32 14, i32 22, i32 30, i32 7, i32 15, i32 23, i32 31>
; CHECK-NEXT: [[INTERLEAVED_VEC:%.*]] = add <32 x i16> [[WIDE_VEC24]], [[WIDE_VEC]]
; CHECK-NEXT: [[TMP2:%.*]] = or disjoint i64 [[OFFSET_IDX]], 3
; CHECK-NEXT: [[GEP:%.*]] = getelementptr i16, ptr [[INVARIANT_GEP]], i64 [[TMP2]]
; CHECK-NEXT: store <32 x i16> [[INTERLEAVED_VEC]], ptr [[GEP]], align 2
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 8
; CHECK-NEXT: [[TMP9:%.*]] = icmp eq i64 [[INDEX_NEXT]], 256
; CHECK-NEXT: br i1 [[TMP9]], label [[FOR_END:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
; CHECK-NEXT: [[TMP3:%.*]] = icmp eq i64 [[INDEX_NEXT]], 256
; CHECK-NEXT: br i1 [[TMP3]], label [[FOR_END:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
; CHECK: for.end:
; CHECK-NEXT: ret void
;
Expand Down Expand Up @@ -412,22 +406,13 @@ define void @addmul(ptr noalias noundef %x, ptr noundef %y, ptr noundef %z, i32
; CHECK-NEXT: [[TMP2:%.*]] = mul <32 x i16> [[WIDE_VEC31]], [[WIDE_VEC]]
; CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds i16, ptr [[X]], i64 [[OFFSET_IDX]]
; CHECK-NEXT: [[WIDE_VEC36:%.*]] = load <32 x i16>, ptr [[TMP3]], align 2
; CHECK-NEXT: [[TMP4:%.*]] = add <32 x i16> [[TMP2]], [[WIDE_VEC36]]
; CHECK-NEXT: [[TMP5:%.*]] = mul <32 x i16> [[WIDE_VEC31]], [[WIDE_VEC]]
; CHECK-NEXT: [[TMP6:%.*]] = add <32 x i16> [[TMP5]], [[WIDE_VEC36]]
; CHECK-NEXT: [[TMP7:%.*]] = mul <32 x i16> [[WIDE_VEC31]], [[WIDE_VEC]]
; CHECK-NEXT: [[TMP8:%.*]] = add <32 x i16> [[TMP7]], [[WIDE_VEC36]]
; CHECK-NEXT: [[TMP9:%.*]] = or disjoint i64 [[OFFSET_IDX]], 3
; CHECK-NEXT: [[TMP10:%.*]] = mul <32 x i16> [[WIDE_VEC31]], [[WIDE_VEC]]
; CHECK-NEXT: [[TMP11:%.*]] = add <32 x i16> [[TMP10]], [[WIDE_VEC36]]
; CHECK-NEXT: [[GEP:%.*]] = getelementptr i16, ptr [[INVARIANT_GEP]], i64 [[TMP9]]
; CHECK-NEXT: [[TMP12:%.*]] = shufflevector <32 x i16> [[TMP4]], <32 x i16> [[TMP6]], <16 x i32> <i32 0, i32 4, i32 8, i32 12, i32 16, i32 20, i32 24, i32 28, i32 33, i32 37, i32 41, i32 45, i32 49, i32 53, i32 57, i32 61>
; CHECK-NEXT: [[TMP13:%.*]] = shufflevector <32 x i16> [[TMP8]], <32 x i16> [[TMP11]], <16 x i32> <i32 2, i32 6, i32 10, i32 14, i32 18, i32 22, i32 26, i32 30, i32 35, i32 39, i32 43, i32 47, i32 51, i32 55, i32 59, i32 63>
; CHECK-NEXT: [[INTERLEAVED_VEC:%.*]] = shufflevector <16 x i16> [[TMP12]], <16 x i16> [[TMP13]], <32 x i32> <i32 0, i32 8, i32 16, i32 24, i32 1, i32 9, i32 17, i32 25, i32 2, i32 10, i32 18, i32 26, i32 3, i32 11, i32 19, i32 27, i32 4, i32 12, i32 20, i32 28, i32 5, i32 13, i32 21, i32 29, i32 6, i32 14, i32 22, i32 30, i32 7, i32 15, i32 23, i32 31>
; CHECK-NEXT: [[INTERLEAVED_VEC:%.*]] = add <32 x i16> [[TMP2]], [[WIDE_VEC36]]
; CHECK-NEXT: [[TMP4:%.*]] = or disjoint i64 [[OFFSET_IDX]], 3
; CHECK-NEXT: [[GEP:%.*]] = getelementptr i16, ptr [[INVARIANT_GEP]], i64 [[TMP4]]
; CHECK-NEXT: store <32 x i16> [[INTERLEAVED_VEC]], ptr [[GEP]], align 2
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 8
; CHECK-NEXT: [[TMP14:%.*]] = icmp eq i64 [[INDEX_NEXT]], 256
; CHECK-NEXT: br i1 [[TMP14]], label [[FOR_END:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP5:![0-9]+]]
; CHECK-NEXT: [[TMP5:%.*]] = icmp eq i64 [[INDEX_NEXT]], 256
; CHECK-NEXT: br i1 [[TMP5]], label [[FOR_END:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP5:![0-9]+]]
; CHECK: for.end:
; CHECK-NEXT: ret void
;
Expand Down
Loading
Loading