Skip to content

Commit 360e584

Browse files
authored
Merge pull request intel#181 from frasercrmck/vecz-subgroup-shuffle-xors
[vecz] Packetize sub-group shuffle_xor builtins
2 parents e73debc + 357a73e commit 360e584

File tree

2 files changed

+380
-10
lines changed

2 files changed

+380
-10
lines changed

llvm/lib/SYCLNativeCPUUtils/compiler_passes/vecz/source/transform/packetizer.cpp

Lines changed: 179 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,17 @@ class Packetizer::Impl : public Packetizer {
221221
///
222222
/// @return Packetized instructions.
223223
Value *packetizeSubgroupShuffle(Instruction *Ins);
224+
/// @brief Packetize a sub-group shuffle-xor builtin
225+
///
226+
/// Note - not any shuffle-like operation, but specifically the 'shuffle_xor'
227+
/// builtin.
228+
///
229+
/// @param[in] Ins Instruction to packetize.
230+
/// @param[in] ShuffleXor Shuffle to packetize.
231+
///
232+
/// @return Packetized instructions.
233+
Result packetizeSubgroupShuffleXor(
234+
Instruction *Ins, compiler::utils::GroupCollective ShuffleXor);
224235

225236
/// @brief Packetize PHI node.
226237
///
@@ -926,10 +937,19 @@ Packetizer::Result Packetizer::Impl::packetize(Value *V) {
926937
}
927938

928939
if (auto shuffle = isSubgroupShuffleLike(Ins)) {
929-
if (shuffle->Op == compiler::utils::GroupCollective::OpKind::Shuffle) {
930-
if (auto *s = packetizeSubgroupShuffle(Ins)) {
931-
return broadcast(s);
932-
}
940+
switch (shuffle->Op) {
941+
default:
942+
break;
943+
case compiler::utils::GroupCollective::OpKind::Shuffle:
944+
if (auto *s = packetizeSubgroupShuffle(Ins)) {
945+
return broadcast(s);
946+
}
947+
break;
948+
case compiler::utils::GroupCollective::OpKind::ShuffleXor:
949+
if (auto s = packetizeSubgroupShuffleXor(Ins, *shuffle)) {
950+
return s;
951+
}
952+
break;
933953
}
934954
// We can't packetize all sub-group shuffle-like operations, but we also
935955
// can't vectorize or instantiate them - so provide a diagnostic saying as
@@ -1414,6 +1434,161 @@ Value *Packetizer::Impl::packetizeSubgroupShuffle(Instruction *I) {
14141434
return CI;
14151435
}
14161436

1437+
Packetizer::Result Packetizer::Impl::packetizeSubgroupShuffleXor(
1438+
Instruction *I, compiler::utils::GroupCollective ShuffleXor) {
1439+
auto *const CI = cast<CallInst>(I);
1440+
1441+
// We don't support scalable vectorization of sub-group shuffles.
1442+
if (SimdWidth.isScalable()) {
1443+
return Packetizer::Result(*this);
1444+
}
1445+
unsigned const VF = SimdWidth.getFixedValue();
1446+
1447+
auto *const Data = CI->getArgOperand(0);
1448+
auto *const Val = CI->getArgOperand(1);
1449+
1450+
auto PackData = packetize(Data);
1451+
if (!PackData) {
1452+
return Packetizer::Result(*this);
1453+
}
1454+
1455+
// If the data operand happened to be a broadcast value already, we can use
1456+
// it directly.
1457+
if (PackData.info->numInstances == 0) {
1458+
IC.deleteInstructionLater(CI);
1459+
CI->replaceAllUsesWith(Data);
1460+
return PackData;
1461+
}
1462+
1463+
auto PackVal = packetize(Val);
1464+
if (!PackVal) {
1465+
return Packetizer::Result(*this);
1466+
}
1467+
1468+
// With the packetize operands in place, we have to perform the actual
1469+
// shuffling operation. Since we are one layer higher than the mux
1470+
// sub-groups, our IDs do not easily translate to the mux level. Therefore we
1471+
// perform each shuffle using the regular 'shuffle' and do the XOR of the IDs
1472+
// ourselves.
1473+
1474+
// Note: in this illustrative example, imagine two invocations across a
1475+
// single mux sub-groups, each being vectorized by 4; in other words, 8
1476+
// 'original' invocations to a sub-group, running in two vectorized
1477+
// invocations. Imagine value = 5:
1478+
// | shuffle(X, 5) | shuffle(A, 5) |
1479+
// VF=4 |----------------------|----------------------|
1480+
// | s(<X,Y,Z,W>, 5) | s(<A,B,C,D>, 5) |
1481+
// SG IDs | 0,1,2,3 | 4,5,6,7 |
1482+
// SG IDs^5 | 5,4,7,6 | 1,0,3,2 |
1483+
// I=(SG IDs^5)/4 | 1,1,1,1 | 0,0,0,0 |
1484+
// J=(SG IDs^5)%4 | 1,0,3,2 | 1,0,3,2 |
1485+
// <X,Y,Z,W>[J] | Y,X,W,Z | B,A,D,A |
1486+
// Mux-shuffle[I] | [Y,B][1],[X,A][1],.. | [Y,B][0],[X,A][1],.. |
1487+
// | B,A,D,A | Y,X,W,Z |
1488+
IRBuilder<> B(CI);
1489+
1490+
auto *const SubgroupLocalIDFn = Ctx.builtins().getOrDeclareMuxBuiltin(
1491+
compiler::utils::eMuxBuiltinGetSubGroupLocalId, *F.getParent(),
1492+
{CI->getType()});
1493+
assert(SubgroupLocalIDFn);
1494+
1495+
auto *const SubgroupLocalID =
1496+
B.CreateCall(SubgroupLocalIDFn, {}, "sg.local.id");
1497+
auto const Builtin =
1498+
Ctx.builtins().analyzeBuiltinCall(*SubgroupLocalID, Dimension);
1499+
1500+
// Vectorize the sub-group local ID
1501+
auto *const VecSubgroupLocalID =
1502+
vectorizeWorkGroupCall(SubgroupLocalID, Builtin);
1503+
if (!VecSubgroupLocalID) {
1504+
return Packetizer::Result(*this);
1505+
}
1506+
VecSubgroupLocalID->setName("vec.sg.local.id");
1507+
1508+
// The value is always i32, as is the sub-group local ID. Vectorizing both of
1509+
// them should result in the same vector type, with as many elements as the
1510+
// vectorization factor.
1511+
auto *const VecVal = PackVal.getAsValue();
1512+
1513+
assert(VecVal->getType() == VecSubgroupLocalID->getType() &&
1514+
VecVal->getType()->isVectorTy() &&
1515+
cast<VectorType>(VecVal->getType())
1516+
->getElementCount()
1517+
.getKnownMinValue() == VF &&
1518+
"Unexpected vectorization of sub-group shuffle xor");
1519+
1520+
// Perform the XOR of the sub-group IDs with the 'value', as per the
1521+
// semantics of the builtin.
1522+
auto *const XoredID = B.CreateXor(VecSubgroupLocalID, VecVal);
1523+
1524+
// We need to sanitize the input index so that it stays within the range of
1525+
// one vectorized group.
1526+
auto *const VecIdxFactor = ConstantInt::get(SubgroupLocalID->getType(), VF);
1527+
1528+
// Bring this ID into the range of 'mux' sub-groups by dividing it by the
1529+
// vector size.
1530+
auto *const MuxXoredID =
1531+
B.CreateUDiv(XoredID, B.CreateVectorSplat(VF, VecIdxFactor));
1532+
// And into the range of the vector group
1533+
auto *const VecXoredID =
1534+
B.CreateURem(XoredID, B.CreateVectorSplat(VF, VecIdxFactor));
1535+
1536+
// Now we defer to an *exclusive* scan over the group.
1537+
auto RegularShuffle = ShuffleXor;
1538+
RegularShuffle.Op = compiler::utils::GroupCollective::OpKind::Shuffle;
1539+
1540+
auto RegularShuffleID = Ctx.builtins().getMuxGroupCollective(RegularShuffle);
1541+
assert(RegularShuffleID != compiler::utils::eBuiltinInvalid);
1542+
1543+
auto *const RegularShuffleFn = Ctx.builtins().getOrDeclareMuxBuiltin(
1544+
RegularShuffleID, *F.getParent(), {CI->getType()});
1545+
assert(RegularShuffleFn);
1546+
1547+
auto *const VecData = PackData.getAsValue();
1548+
Value *CombinedShuffle = UndefValue::get(VecData->getType());
1549+
1550+
for (unsigned i = 0; i < VF; i++) {
1551+
auto *Idx = B.getInt32(i);
1552+
// Get the XORd index local to the vector group that this vector group
1553+
// element wants to shuffle with.
1554+
auto *const VecGroupIdx = B.CreateExtractElement(VecXoredID, Idx);
1555+
// Grab that element. It may be a vector, in which case we must extract
1556+
// each element individually.
1557+
Value *DataElt = nullptr;
1558+
if (auto *DataVecTy = dyn_cast<VectorType>(Data->getType()); !DataVecTy) {
1559+
DataElt = B.CreateExtractElement(VecData, VecGroupIdx);
1560+
} else {
1561+
DataElt = UndefValue::get(DataVecTy);
1562+
auto VecWidth = DataVecTy->getElementCount().getFixedValue();
1563+
// VecGroupIdx is the 'base' of the subvector, whose elements are stored
1564+
// sequentially from that point.
1565+
auto *const VecVecGroupIdx =
1566+
B.CreateMul(VecGroupIdx, B.getInt32(VecWidth));
1567+
for (unsigned j = 0; j != VecWidth; j++) {
1568+
auto *const Elt = B.CreateExtractElement(
1569+
VecData, B.CreateAdd(VecVecGroupIdx, B.getInt32(j)));
1570+
DataElt = B.CreateInsertElement(DataElt, Elt, B.getInt32(j));
1571+
}
1572+
}
1573+
assert(DataElt);
1574+
// Shuffle it across the mux sub-group.
1575+
auto *const MuxID = B.CreateExtractElement(MuxXoredID, Idx);
1576+
auto *const Shuff = B.CreateCall(RegularShuffleFn, {DataElt, MuxID});
1577+
// Combine that back into the final shuffled vector.
1578+
if (auto *DataVecTy = dyn_cast<VectorType>(Data->getType()); !DataVecTy) {
1579+
CombinedShuffle = B.CreateInsertElement(CombinedShuffle, Shuff, Idx);
1580+
} else {
1581+
auto VecWidth = DataVecTy->getElementCount().getFixedValue();
1582+
CombinedShuffle = B.CreateInsertVector(
1583+
CombinedShuffle->getType(), CombinedShuffle, Shuff,
1584+
B.getInt64(static_cast<uint64_t>(i) * VecWidth));
1585+
}
1586+
}
1587+
1588+
IC.deleteInstructionLater(CI);
1589+
return assign(CI, CombinedShuffle);
1590+
}
1591+
14171592
Value *Packetizer::Impl::packetizeMaskVarying(Instruction *I) {
14181593
if (auto memop = MemOp::get(I)) {
14191594
auto *const mask = memop->getMaskOperand();

0 commit comments

Comments
 (0)