@@ -221,6 +221,17 @@ class Packetizer::Impl : public Packetizer {
221221  // /
222222  // / @return Packetized instructions.
223223  Value *packetizeSubgroupShuffle (Instruction *Ins);
224+   // / @brief Packetize a sub-group shuffle-xor builtin.
225+   // /
226+   // / Note - not any shuffle-like operation, but specifically the 'shuffle_xor'
227+   // / builtin.
228+   // /
        // / The implementation XORs the vectorized sub-group local IDs with the
        // / packetized 'value' operand, then emits one regular mux 'shuffle' per
        // / vector lane and recombines the results.
        // /
229+   // / @param[in] Ins Instruction to packetize. The implementation casts it to a
        // / CallInst; operand 0 is read as the data and operand 1 as the XOR value.
230+   // / @param[in] ShuffleXor Group collective describing the shuffle to
        // / packetize. A copy with its Op set to Shuffle is used to look up the
        // / regular mux shuffle builtin.
231+   // /
232+   // / @return Packetized instructions. A Result built from only `*this` is
        // / returned when the SIMD width is scalable, when either operand fails to
        // / packetize, or when the sub-group local ID call cannot be vectorized
        // / (NOTE(review): presumably the 'failed' state, since the caller tests
        // / the returned Result in an `if` - confirm).
233+   Result packetizeSubgroupShuffleXor (
234+       Instruction *Ins, compiler::utils::GroupCollective ShuffleXor);
224235
225236  // / @brief Packetize PHI node.
226237  // /
@@ -926,10 +937,19 @@ Packetizer::Result Packetizer::Impl::packetize(Value *V) {
926937  }
927938
928939  if  (auto  shuffle = isSubgroupShuffleLike (Ins)) {
929-     if  (shuffle->Op  == compiler::utils::GroupCollective::OpKind::Shuffle) {
930-       if  (auto  *s = packetizeSubgroupShuffle (Ins)) {
931-         return  broadcast (s);
932-       }
940+     switch  (shuffle->Op ) {
941+       default :
942+         break ;
943+       case  compiler::utils::GroupCollective::OpKind::Shuffle:
944+         if  (auto  *s = packetizeSubgroupShuffle (Ins)) {
945+           return  broadcast (s);
946+         }
947+         break ;
948+       case  compiler::utils::GroupCollective::OpKind::ShuffleXor:
949+         if  (auto  s = packetizeSubgroupShuffleXor (Ins, *shuffle)) {
950+           return  s;
951+         }
952+         break ;
933953    }
934954    //  We can't packetize all sub-group shuffle-like operations, but we also
935955    //  can't vectorize or instantiate them - so provide a diagnostic saying as
@@ -1414,6 +1434,161 @@ Value *Packetizer::Impl::packetizeSubgroupShuffle(Instruction *I) {
14141434  return  CI;
14151435}
14161436
// Packetize a call to the sub-group 'shuffle_xor' builtin.
//
// Operand 0 of the call is the data to shuffle and operand 1 is the XOR value.
// Both operands are packetized. The vectorized sub-group local IDs are XORed
// with the value, and the result is split into a mux sub-group index (ID / VF)
// and a lane index within the vector group (ID % VF). For each of the VF lanes
// the wanted data element is extracted from the packetized data, shuffled
// across the mux sub-group with the regular 'shuffle' builtin, and inserted
// into the combined result vector.
//
// Returns Packetizer::Result(*this) without emitting a shuffle when the SIMD
// width is scalable, when either operand fails to packetize, or when the
// sub-group local ID call cannot be vectorized.
1437+ Packetizer::Result Packetizer::Impl::packetizeSubgroupShuffleXor (
1438+     Instruction *I, compiler::utils::GroupCollective ShuffleXor) {
1439+   auto  *const  CI = cast<CallInst>(I);
1440+ 
1441+   //  We don't support scalable vectorization of sub-group shuffles.
1442+   if  (SimdWidth.isScalable ()) {
1443+     return  Packetizer::Result (*this );
1444+   }
1445+   unsigned  const  VF = SimdWidth.getFixedValue ();
1446+ 
        //  Operand 0 is the data being shuffled; operand 1 is the XOR value.
1447+   auto  *const  Data = CI->getArgOperand (0 );
1448+   auto  *const  Val = CI->getArgOperand (1 );
1449+ 
1450+   auto  PackData = packetize (Data);
1451+   if  (!PackData) {
1452+     return  Packetizer::Result (*this );
1453+   }
1454+ 
1455+   //  If the data operand happened to be a broadcast value already, we can use
1456+   //  it directly: the call is replaced by its own data operand.
        //  NOTE(review): relies on numInstances == 0 denoting a broadcast packet -
        //  confirm against the definition of Result.
1457+   if  (PackData.info ->numInstances  == 0 ) {
1458+     IC.deleteInstructionLater (CI);
1459+     CI->replaceAllUsesWith (Data);
1460+     return  PackData;
1461+   }
1462+ 
1463+   auto  PackVal = packetize (Val);
1464+   if  (!PackVal) {
1465+     return  Packetizer::Result (*this );
1466+   }
1467+ 
1468+   //  With the packetized operands in place, we have to perform the actual
1469+   //  shuffling operation. Since we are one layer higher than the mux
1470+   //  sub-groups, our IDs do not easily translate to the mux level. Therefore we
1471+   //  perform each shuffle using the regular 'shuffle' and do the XOR of the IDs
1472+   //  ourselves.
1473+ 
1474+   //  Note: in this illustrative example, imagine two invocations across a
1475+   //  single mux sub-group, each being vectorized by 4; in other words, 8
1476+   //  'original' invocations to a sub-group, running in two vectorized
1477+   //  invocations. Imagine value = 5:
1478+   //                 |  shuffle(X, 5)       |  shuffle(A, 5)       |
1479+   //  VF=4           |----------------------|----------------------|
1480+   //                 |    s(<X,Y,Z,W>, 5)   |    s(<A,B,C,D>, 5)   |
1481+   //  SG IDs         |       0,1,2,3        |       4,5,6,7        |
1482+   //  SG IDs^5       |       5,4,7,6        |       1,0,3,2        |
1483+   //  I=(SG IDs^5)/4 |       1,1,1,1        |       0,0,0,0        |
1484+   //  J=(SG IDs^5)%4 |       1,0,3,2        |       1,0,3,2        |
1485+   //  <X,Y,Z,W>[J]   |       Y,X,W,Z        |       B,A,D,C        |
1486+   //  Mux-shuffle[I] | [Y,B][1],[X,A][1],.. | [Y,B][0],[X,A][0],.. |
1487+   //                 |       B,A,D,C        |       Y,X,W,Z        |
1488+   IRBuilder<> B (CI);
1489+ 
        //  Emit a call to the mux sub-group local ID builtin just before the
        //  original call, so that it can be vectorized below.
1490+   auto  *const  SubgroupLocalIDFn = Ctx.builtins ().getOrDeclareMuxBuiltin (
1491+       compiler::utils::eMuxBuiltinGetSubGroupLocalId, *F.getParent (),
1492+       {CI->getType ()});
1493+   assert (SubgroupLocalIDFn);
1494+ 
1495+   auto  *const  SubgroupLocalID =
1496+       B.CreateCall (SubgroupLocalIDFn, {}, " sg.local.id" 
1497+   auto  const  Builtin =
1498+       Ctx.builtins ().analyzeBuiltinCall (*SubgroupLocalID, Dimension);
1499+ 
1500+   //  Vectorize the sub-group local ID
1501+   auto  *const  VecSubgroupLocalID =
1502+       vectorizeWorkGroupCall (SubgroupLocalID, Builtin);
1503+   if  (!VecSubgroupLocalID) {
1504+     return  Packetizer::Result (*this );
1505+   }
1506+   VecSubgroupLocalID->setName (" vec.sg.local.id" 
1507+ 
1508+   //  The value is always i32, as is the sub-group local ID. Vectorizing both of
1509+   //  them should result in the same vector type, with as many elements as the
1510+   //  vectorization factor.
1511+   auto  *const  VecVal = PackVal.getAsValue ();
1512+ 
1513+   assert (VecVal->getType () == VecSubgroupLocalID->getType () &&
1514+          VecVal->getType ()->isVectorTy () &&
1515+          cast<VectorType>(VecVal->getType ())
1516+                  ->getElementCount ()
1517+                  .getKnownMinValue () == VF &&
1518+          " Unexpected vectorization of sub-group shuffle xor" 
1519+ 
1520+   //  Perform the XOR of the sub-group IDs with the 'value', as per the
1521+   //  semantics of the builtin.
1522+   auto  *const  XoredID = B.CreateXor (VecSubgroupLocalID, VecVal);
1523+ 
1524+   //  The vectorization factor as a constant of the sub-group local ID's type.
1525+   //  It is splatted and used below as both the divisor and the modulus.
1526+   auto  *const  VecIdxFactor = ConstantInt::get (SubgroupLocalID->getType (), VF);
1527+ 
1528+   //  Bring this ID into the range of 'mux' sub-groups by dividing it by the
1529+   //  vector size.
1530+   auto  *const  MuxXoredID =
1531+       B.CreateUDiv (XoredID, B.CreateVectorSplat (VF, VecIdxFactor));
1532+   //  And into the range of the vector group
1533+   auto  *const  VecXoredID =
1534+       B.CreateURem (XoredID, B.CreateVectorSplat (VF, VecIdxFactor));
1535+ 
1536+   //  Now we defer to the regular (non-XOR) 'shuffle' group collective: the
        //  same collective description with only its operation kind swapped.
1537+   auto  RegularShuffle = ShuffleXor;
1538+   RegularShuffle.Op  = compiler::utils::GroupCollective::OpKind::Shuffle;
1539+ 
1540+   auto  RegularShuffleID = Ctx.builtins ().getMuxGroupCollective (RegularShuffle);
1541+   assert (RegularShuffleID != compiler::utils::eBuiltinInvalid);
1542+ 
1543+   auto  *const  RegularShuffleFn = Ctx.builtins ().getOrDeclareMuxBuiltin (
1544+       RegularShuffleID, *F.getParent (), {CI->getType ()});
1545+   assert (RegularShuffleFn);
1546+ 
        //  The result is built up lane by lane, starting from an undef vector of
        //  the same type as the packetized data.
1547+   auto  *const  VecData = PackData.getAsValue ();
1548+   Value *CombinedShuffle = UndefValue::get (VecData->getType ());
1549+ 
1550+   for  (unsigned  i = 0 ; i < VF; i++) {
1551+     auto  *Idx = B.getInt32 (i);
1552+     //  Get the XORd index local to the vector group that this vector group
1553+     //  element wants to shuffle with.
1554+     auto  *const  VecGroupIdx = B.CreateExtractElement (VecXoredID, Idx);
1555+     //  Grab that element. It may be a vector, in which case we must extract
1556+     //  each element individually.
1557+     Value *DataElt = nullptr ;
1558+     if  (auto  *DataVecTy = dyn_cast<VectorType>(Data->getType ()); !DataVecTy) {
1559+       DataElt = B.CreateExtractElement (VecData, VecGroupIdx);
1560+     } else  {
1561+       DataElt = UndefValue::get (DataVecTy);
1562+       auto  VecWidth = DataVecTy->getElementCount ().getFixedValue ();
1563+       //  VecGroupIdx is the 'base' of the subvector, whose elements are stored
1564+       //  sequentially from that point.
1565+       auto  *const  VecVecGroupIdx =
1566+           B.CreateMul (VecGroupIdx, B.getInt32 (VecWidth));
1567+       for  (unsigned  j = 0 ; j != VecWidth; j++) {
1568+         auto  *const  Elt = B.CreateExtractElement (
1569+             VecData, B.CreateAdd (VecVecGroupIdx, B.getInt32 (j)));
1570+         DataElt = B.CreateInsertElement (DataElt, Elt, B.getInt32 (j));
1571+       }
1572+     }
1573+     assert (DataElt);
1574+     //  Shuffle it across the mux sub-group.
1575+     auto  *const  MuxID = B.CreateExtractElement (MuxXoredID, Idx);
1576+     auto  *const  Shuff = B.CreateCall (RegularShuffleFn, {DataElt, MuxID});
1577+     //  Combine that back into the final shuffled vector.
1578+     if  (auto  *DataVecTy = dyn_cast<VectorType>(Data->getType ()); !DataVecTy) {
1579+       CombinedShuffle = B.CreateInsertElement (CombinedShuffle, Shuff, Idx);
1580+     } else  {
          //  For vector data the shuffled value is inserted as a sub-vector at
          //  element offset i * VecWidth.
1581+       auto  VecWidth = DataVecTy->getElementCount ().getFixedValue ();
1582+       CombinedShuffle = B.CreateInsertVector (
1583+           CombinedShuffle->getType (), CombinedShuffle, Shuff,
1584+           B.getInt64 (static_cast <uint64_t >(i) * VecWidth));
1585+     }
1586+   }
1587+ 
        //  Schedule the original call for deletion and hand back the combined
        //  vector (NOTE(review): 'assign' presumably records CombinedShuffle as the
        //  packetized form of CI - confirm).
1588+   IC.deleteInstructionLater (CI);
1589+   return  assign (CI, CombinedShuffle);
1590+ }
1591+ 
14171592Value *Packetizer::Impl::packetizeMaskVarying (Instruction *I) {
14181593  if  (auto  memop = MemOp::get (I)) {
14191594    auto  *const  mask = memop->getMaskOperand ();
0 commit comments