@@ -221,6 +221,17 @@ class Packetizer::Impl : public Packetizer {
221221 // /
222222 // / @return Packetized instructions.
223223 Value *packetizeSubgroupShuffle (Instruction *Ins);
224+ // / @brief Packetize a sub-group shuffle-xor builtin
225+ // /
226+ // / Note - not any shuffle-like operation, but specifically the 'shuffle_xor'
227+ // / builtin.
228+ // /
229+ // / @param[in] Ins Instruction to packetize.
230+ // / @param[in] ShuffleXor Shuffle to packetize.
231+ // /
232+ // / @return Packetized instructions.
233+ Result packetizeSubgroupShuffleXor (
234+ Instruction *Ins, compiler::utils::GroupCollective ShuffleXor);
224235
225236 // / @brief Packetize PHI node.
226237 // /
@@ -926,10 +937,19 @@ Packetizer::Result Packetizer::Impl::packetize(Value *V) {
926937 }
927938
928939 if (auto shuffle = isSubgroupShuffleLike (Ins)) {
929- if (shuffle->Op == compiler::utils::GroupCollective::OpKind::Shuffle) {
930- if (auto *s = packetizeSubgroupShuffle (Ins)) {
931- return broadcast (s);
932- }
940+ switch (shuffle->Op ) {
941+ default :
942+ break ;
943+ case compiler::utils::GroupCollective::OpKind::Shuffle:
944+ if (auto *s = packetizeSubgroupShuffle (Ins)) {
945+ return broadcast (s);
946+ }
947+ break ;
948+ case compiler::utils::GroupCollective::OpKind::ShuffleXor:
949+ if (auto s = packetizeSubgroupShuffleXor (Ins, *shuffle)) {
950+ return s;
951+ }
952+ break ;
933953 }
934954 // We can't packetize all sub-group shuffle-like operations, but we also
935955 // can't vectorize or instantiate them - so provide a diagnostic saying as
@@ -1414,6 +1434,161 @@ Value *Packetizer::Impl::packetizeSubgroupShuffle(Instruction *I) {
14141434 return CI;
14151435}
14161436
1437+ Packetizer::Result Packetizer::Impl::packetizeSubgroupShuffleXor (
1438+ Instruction *I, compiler::utils::GroupCollective ShuffleXor) {
1439+ auto *const CI = cast<CallInst>(I);
1440+
1441+ // We don't support scalable vectorization of sub-group shuffles.
1442+ if (SimdWidth.isScalable ()) {
1443+ return Packetizer::Result (*this );
1444+ }
1445+ unsigned const VF = SimdWidth.getFixedValue ();
1446+
1447+ auto *const Data = CI->getArgOperand (0 );
1448+ auto *const Val = CI->getArgOperand (1 );
1449+
1450+ auto PackData = packetize (Data);
1451+ if (!PackData) {
1452+ return Packetizer::Result (*this );
1453+ }
1454+
1455+ // If the data operand happened to be a broadcast value already, we can use
1456+ // it directly.
1457+ if (PackData.info ->numInstances == 0 ) {
1458+ IC.deleteInstructionLater (CI);
1459+ CI->replaceAllUsesWith (Data);
1460+ return PackData;
1461+ }
1462+
1463+ auto PackVal = packetize (Val);
1464+ if (!PackVal) {
1465+ return Packetizer::Result (*this );
1466+ }
1467+
1468+ // With the packetize operands in place, we have to perform the actual
1469+ // shuffling operation. Since we are one layer higher than the mux
1470+ // sub-groups, our IDs do not easily translate to the mux level. Therefore we
1471+ // perform each shuffle using the regular 'shuffle' and do the XOR of the IDs
1472+ // ourselves.
1473+
1474+ // Note: in this illustrative example, imagine two invocations across a
1475+ // single mux sub-groups, each being vectorized by 4; in other words, 8
1476+ // 'original' invocations to a sub-group, running in two vectorized
1477+ // invocations. Imagine value = 5:
1478+ // | shuffle(X, 5) | shuffle(A, 5) |
1479+ // VF=4 |----------------------|----------------------|
1480+ // | s(<X,Y,Z,W>, 5) | s(<A,B,C,D>, 5) |
1481+ // SG IDs | 0,1,2,3 | 4,5,6,7 |
1482+ // SG IDs^5 | 5,4,7,6 | 1,0,3,2 |
1483+ // I=(SG IDs^5)/4 | 1,1,1,1 | 0,0,0,0 |
1484+ // J=(SG IDs^5)%4 | 1,0,3,2 | 1,0,3,2 |
1485+ // <X,Y,Z,W>[J] | Y,X,W,Z | B,A,D,A |
1486+ // Mux-shuffle[I] | [Y,B][1],[X,A][1],.. | [Y,B][0],[X,A][1],.. |
1487+ // | B,A,D,A | Y,X,W,Z |
1488+ IRBuilder<> B (CI);
1489+
1490+ auto *const SubgroupLocalIDFn = Ctx.builtins ().getOrDeclareMuxBuiltin (
1491+ compiler::utils::eMuxBuiltinGetSubGroupLocalId, *F.getParent (),
1492+ {CI->getType ()});
1493+ assert (SubgroupLocalIDFn);
1494+
1495+ auto *const SubgroupLocalID =
1496+ B.CreateCall (SubgroupLocalIDFn, {}, " sg.local.id" );
1497+ auto const Builtin =
1498+ Ctx.builtins ().analyzeBuiltinCall (*SubgroupLocalID, Dimension);
1499+
1500+ // Vectorize the sub-group local ID
1501+ auto *const VecSubgroupLocalID =
1502+ vectorizeWorkGroupCall (SubgroupLocalID, Builtin);
1503+ if (!VecSubgroupLocalID) {
1504+ return Packetizer::Result (*this );
1505+ }
1506+ VecSubgroupLocalID->setName (" vec.sg.local.id" );
1507+
1508+ // The value is always i32, as is the sub-group local ID. Vectorizing both of
1509+ // them should result in the same vector type, with as many elements as the
1510+ // vectorization factor.
1511+ auto *const VecVal = PackVal.getAsValue ();
1512+
1513+ assert (VecVal->getType () == VecSubgroupLocalID->getType () &&
1514+ VecVal->getType ()->isVectorTy () &&
1515+ cast<VectorType>(VecVal->getType ())
1516+ ->getElementCount ()
1517+ .getKnownMinValue () == VF &&
1518+ " Unexpected vectorization of sub-group shuffle xor" );
1519+
1520+ // Perform the XOR of the sub-group IDs with the 'value', as per the
1521+ // semantics of the builtin.
1522+ auto *const XoredID = B.CreateXor (VecSubgroupLocalID, VecVal);
1523+
1524+ // We need to sanitize the input index so that it stays within the range of
1525+ // one vectorized group.
1526+ auto *const VecIdxFactor = ConstantInt::get (SubgroupLocalID->getType (), VF);
1527+
1528+ // Bring this ID into the range of 'mux' sub-groups by dividing it by the
1529+ // vector size.
1530+ auto *const MuxXoredID =
1531+ B.CreateUDiv (XoredID, B.CreateVectorSplat (VF, VecIdxFactor));
1532+ // And into the range of the vector group
1533+ auto *const VecXoredID =
1534+ B.CreateURem (XoredID, B.CreateVectorSplat (VF, VecIdxFactor));
1535+
1536+ // Now we defer to an *exclusive* scan over the group.
1537+ auto RegularShuffle = ShuffleXor;
1538+ RegularShuffle.Op = compiler::utils::GroupCollective::OpKind::Shuffle;
1539+
1540+ auto RegularShuffleID = Ctx.builtins ().getMuxGroupCollective (RegularShuffle);
1541+ assert (RegularShuffleID != compiler::utils::eBuiltinInvalid);
1542+
1543+ auto *const RegularShuffleFn = Ctx.builtins ().getOrDeclareMuxBuiltin (
1544+ RegularShuffleID, *F.getParent (), {CI->getType ()});
1545+ assert (RegularShuffleFn);
1546+
1547+ auto *const VecData = PackData.getAsValue ();
1548+ Value *CombinedShuffle = UndefValue::get (VecData->getType ());
1549+
1550+ for (unsigned i = 0 ; i < VF; i++) {
1551+ auto *Idx = B.getInt32 (i);
1552+ // Get the XORd index local to the vector group that this vector group
1553+ // element wants to shuffle with.
1554+ auto *const VecGroupIdx = B.CreateExtractElement (VecXoredID, Idx);
1555+ // Grab that element. It may be a vector, in which case we must extract
1556+ // each element individually.
1557+ Value *DataElt = nullptr ;
1558+ if (auto *DataVecTy = dyn_cast<VectorType>(Data->getType ()); !DataVecTy) {
1559+ DataElt = B.CreateExtractElement (VecData, VecGroupIdx);
1560+ } else {
1561+ DataElt = UndefValue::get (DataVecTy);
1562+ auto VecWidth = DataVecTy->getElementCount ().getFixedValue ();
1563+ // VecGroupIdx is the 'base' of the subvector, whose elements are stored
1564+ // sequentially from that point.
1565+ auto *const VecVecGroupIdx =
1566+ B.CreateMul (VecGroupIdx, B.getInt32 (VecWidth));
1567+ for (unsigned j = 0 ; j != VecWidth; j++) {
1568+ auto *const Elt = B.CreateExtractElement (
1569+ VecData, B.CreateAdd (VecVecGroupIdx, B.getInt32 (j)));
1570+ DataElt = B.CreateInsertElement (DataElt, Elt, B.getInt32 (j));
1571+ }
1572+ }
1573+ assert (DataElt);
1574+ // Shuffle it across the mux sub-group.
1575+ auto *const MuxID = B.CreateExtractElement (MuxXoredID, Idx);
1576+ auto *const Shuff = B.CreateCall (RegularShuffleFn, {DataElt, MuxID});
1577+ // Combine that back into the final shuffled vector.
1578+ if (auto *DataVecTy = dyn_cast<VectorType>(Data->getType ()); !DataVecTy) {
1579+ CombinedShuffle = B.CreateInsertElement (CombinedShuffle, Shuff, Idx);
1580+ } else {
1581+ auto VecWidth = DataVecTy->getElementCount ().getFixedValue ();
1582+ CombinedShuffle = B.CreateInsertVector (
1583+ CombinedShuffle->getType (), CombinedShuffle, Shuff,
1584+ B.getInt64 (static_cast <uint64_t >(i) * VecWidth));
1585+ }
1586+ }
1587+
1588+ IC.deleteInstructionLater (CI);
1589+ return assign (CI, CombinedShuffle);
1590+ }
1591+
14171592Value *Packetizer::Impl::packetizeMaskVarying (Instruction *I) {
14181593 if (auto memop = MemOp::get (I)) {
14191594 auto *const mask = memop->getMaskOperand ();
0 commit comments