|
17 | 17 | #include "transform/packetizer.h" |
18 | 18 |
|
19 | 19 | #include <compiler/utils/builtin_info.h> |
| 20 | +#include <compiler/utils/group_collective_helpers.h> |
20 | 21 | #include <compiler/utils/mangling.h> |
21 | 22 | #include <llvm/ADT/DepthFirstIterator.h> |
22 | 23 | #include <llvm/ADT/SmallPtrSet.h> |
23 | 24 | #include <llvm/ADT/Statistic.h> |
24 | 25 | #include <llvm/ADT/Twine.h> |
25 | 26 | #include <llvm/Analysis/LoopInfo.h> |
26 | 27 | #include <llvm/Analysis/VectorUtils.h> |
| 28 | +#include <llvm/IR/Constants.h> |
27 | 29 | #include <llvm/IR/DIBuilder.h> |
28 | 30 | #include <llvm/IR/DebugInfoMetadata.h> |
29 | 31 | #include <llvm/IR/IRBuilder.h> |
@@ -202,13 +204,24 @@ class Packetizer::Impl : public Packetizer { |
202 | 204 | /// |
203 | 205 | /// @return Packetized instruction. |
204 | 206 | Value *packetizeGroupBroadcast(Instruction *I); |
205 | | - /// @brief Returns true if the instruction is a subgroup shuffle. |
| 207 | + /// @brief Determines whether the instruction is any subgroup shuffle. |
206 | 208 | /// |
207 | 209 | /// @param[in] I Instruction to query. |
208 | 210 | /// |
209 | | - /// @return True if the instruction is a call to a mux subgroup shuffle |
| 211 | + /// @return The group collective data if the instruction is a call to any of |
| 212 | + /// the mux subgroup shuffle builtins; std::nullopt otherwise. |
| 213 | + std::optional<compiler::utils::GroupCollective> isSubgroupShuffleLike( |
| 214 | + Instruction *I); |
| 215 | + /// @brief Packetize a sub-group shuffle builtin |
| 216 | + /// |
| 217 | + /// Note - not any shuffle-like operation, but specifically the 'shuffle' |
210 | 218 | /// builtin. |
211 | | - bool isSubgroupShuffle(Instruction *I); |
| 219 | + /// |
| 220 | + /// @param[in] Ins Instruction to packetize. |
| 221 | + /// |
| 222 | + /// @return Packetized instruction, or nullptr if it could not be packetized. |
| 223 | + Value *packetizeSubgroupShuffle(Instruction *Ins); |
| 224 | + |
212 | 225 | /// @brief Packetize PHI node. |
213 | 226 | /// |
214 | 227 | /// @param[in] Phi PHI Node to packetize. |
@@ -861,12 +874,6 @@ Packetizer::Result Packetizer::Impl::packetize(Value *V) { |
861 | 874 |
|
862 | 875 | auto *const Ins = cast<Instruction>(V); |
863 | 876 |
|
864 | | - // FIXME: Add support for vectorizing sub-group shuffles |
865 | | - if (isSubgroupShuffle(Ins)) { |
866 | | - emitVeczRemarkMissed(&F, Ins, "Could not packetize sub-group shuffle"); |
867 | | - return Packetizer::Result(*this); |
868 | | - } |
869 | | - |
870 | 877 | if (auto *const Branch = dyn_cast<BranchInst>(Ins)) { |
871 | 878 | if (Branch->isConditional()) { |
872 | 879 | // varying reductions need to be packetized |
@@ -918,6 +925,19 @@ Packetizer::Result Packetizer::Impl::packetize(Value *V) { |
918 | 925 | return broadcast(brdcast); |
919 | 926 | } |
920 | 927 |
|
| 928 | + if (auto shuffle = isSubgroupShuffleLike(Ins)) { |
| 929 | + if (shuffle->Op == compiler::utils::GroupCollective::OpKind::Shuffle) { |
| 930 | + if (auto *s = packetizeSubgroupShuffle(Ins)) { |
| 931 | + return broadcast(s); |
| 932 | + } |
| 933 | + } |
| 934 | + // We can't packetize all sub-group shuffle-like operations, but we also |
| 935 | + // can't vectorize or instantiate them - so provide a diagnostic saying as |
| 936 | + // much. |
| 937 | + emitVeczRemarkMissed(&F, Ins, "Could not packetize sub-group shuffle"); |
| 938 | + return Packetizer::Result(*this); |
| 939 | + } |
| 940 | + |
921 | 941 | // Check if we should broadcast the instruction. |
922 | 942 | // Broadcast uniform instructions, unless we want to packetize uniform |
923 | 943 | // instructions as well. We can assume that isMaskVarying is false at this |
@@ -1265,18 +1285,133 @@ Value *Packetizer::Impl::packetizeGroupBroadcast(Instruction *I) { |
1265 | 1285 | return CI; |
1266 | 1286 | } |
1267 | 1287 |
|
1268 | | -bool Packetizer::Impl::isSubgroupShuffle(Instruction *I) { |
| 1288 | +std::optional<compiler::utils::GroupCollective> |
| 1289 | +Packetizer::Impl::isSubgroupShuffleLike(Instruction *I) { |
1269 | 1290 | auto *const CI = dyn_cast<CallInst>(I); |
1270 | 1291 | if (!CI || !CI->getCalledFunction()) { |
1271 | | - return false; |
| 1292 | + return std::nullopt; |
1272 | 1293 | } |
1273 | 1294 | compiler::utils::BuiltinInfo &BI = Ctx.builtins(); |
1274 | 1295 | Function *callee = CI->getCalledFunction(); |
1275 | 1296 |
|
1276 | 1297 | auto const Builtin = BI.analyzeBuiltin(*callee); |
1277 | 1298 | auto const Info = BI.isMuxGroupCollective(Builtin.ID); |
1278 | 1299 |
|
1279 | | - return Info && Info->isSubGroupScope() && Info->isShuffleLike(); |
| 1300 | + if (Info && Info->isSubGroupScope() && Info->isShuffleLike()) { |
| 1301 | + return Info; |
| 1302 | + } |
| 1303 | + |
| 1304 | + return std::nullopt; |
| 1305 | +} |
| 1306 | + |
| 1307 | +Value *Packetizer::Impl::packetizeSubgroupShuffle(Instruction *I) { |
| 1308 | + auto *const CI = cast<CallInst>(I); |
| 1309 | + |
| 1310 | + // Scalable vectorization of sub-group shuffles is unsupported: bail out. |
| 1311 | + if (SimdWidth.isScalable()) { |
| 1312 | + return nullptr; |
| 1313 | + } |
| 1314 | + |
| 1315 | + auto *const Data = CI->getArgOperand(0); |
| 1316 | + auto *const Idx = CI->getArgOperand(1); |
| 1317 | + |
| 1318 | + auto PackData = packetize(Data); |
| 1319 | + if (!PackData) { |
| 1320 | + return nullptr; |
| 1321 | + } |
| 1322 | + |
| 1323 | + // If the data operand happened to be a broadcast value already, we can use |
| 1324 | + // it directly in place of the shuffle call. |
| 1325 | + if (PackData.info->numInstances == 0) { |
| 1326 | + IC.deleteInstructionLater(CI); |
| 1327 | + CI->replaceAllUsesWith(Data); |
| 1328 | + return Data; |
| 1329 | + } |
| 1330 | + |
| 1331 | + // We can't packetize varying shuffle indices yet. |
| 1332 | + if (UVR.isVarying(Idx)) { |
| 1333 | + return nullptr; |
| 1334 | + } |
| 1335 | + |
| 1336 | + IRBuilder<> B(CI); |
| 1337 | + |
| 1338 | + // We need to sanitize the input index so that it stays within the range of |
| 1339 | + // one vectorized group. |
| 1340 | + unsigned const VF = SimdWidth.getFixedValue(); |
| 1341 | + auto *const VecIdxFactor = ConstantInt::get(Idx->getType(), VF); |
| 1342 | + // This index is the element of the vector-group which holds the desired |
| 1343 | + // data, per mux sub-group invocation. Example with VF=2: |
| 1344 | + // <x, y>, <z, w>: idx 1 -> vector element 1, idx 2 -> vector element 0. |
| 1345 | + auto *const VecIdx = B.CreateURem(Idx, VecIdxFactor); |
| 1346 | + // This index is the mux sub-group invocation in which the desired data resides. |
| 1347 | + // <x, y>, <z, w>: idx 1 -> invocation 0, idx 3 -> invocation 1. |
| 1348 | + auto *const MuxIdx = B.CreateUDiv(Idx, VecIdxFactor); |
| 1349 | + |
| 1350 | + Value *VecData = PackData.getAsValue(); |
| 1351 | + |
| 1352 | + // Note: in each illustrative example, imagine two invocations in a |
| 1353 | + // single mux sub-group, each being vectorized by 4; in other words, 8 |
| 1354 | + // 'original' invocations to a sub-group, running in two vectorized |
| 1355 | + // invocations. |
| 1356 | + if (auto *const DataVecTy = dyn_cast<VectorType>(Data->getType()); |
| 1357 | + !DataVecTy) { |
| 1358 | + // The vectorized shuffle is producing a scalar (assuming uniform indices, |
| 1359 | + // see above). Imagine i=6 (6 % 4 = 2 and 6 / 4 = 1): |
| 1360 | + // | shuffle(X, 6) | shuffle(A, 6) | |
| 1361 | + // VF=4 |-----------------|-----------------| |
| 1362 | + // | s(<X,Y,Z,W>, 2) | s(<A,B,C,D>, 2) | |
| 1363 | + // elt 2 | Z | C | |
| 1364 | + // shuff | shuffle(Z, 1) | shuffle(C, 1) | |
| 1365 | + // | C | C | |
| 1366 | + // bcast | <C,C,C,C> | <C,C,C,C> | |
| 1367 | + // This way we can see how each of the 8 invocations ends up with the 6th |
| 1368 | + // element of the total sub-group. |
| 1369 | + VecData = B.CreateExtractElement(VecData, VecIdx, "vec.extract"); |
| 1370 | + } else if (auto *const CIdx = dyn_cast<ConstantInt>(VecIdx)) { |
| 1371 | + // The shuffle produces a vector, and we have a constant shuffle index - we |
| 1372 | + // can extract a subvector easily. |
| 1373 | + // Imagine i=6 (6 % 4 = 2 and 6 / 4 = 1): |
| 1374 | + // | shuffle(<X,Y>, 6) | shuffle(<A,B>, 6) | |
| 1375 | + // VF=4 |-------------------------|-------------------------| |
| 1376 | + // | s(<X,Y,Z,W,P,Q,-,->, 2) | s(<A,B,C,D,E,F,-,->, 2) | |
| 1377 | + // vec 2 | <P,Q> | <E,F> | |
| 1378 | + // shuff | shuffle(<P,Q>, 1) | shuffle(<E,F>, 1) | |
| 1379 | + // | <E,F> | <E,F> | |
| 1380 | + // bcast | <E,F,E,F,E,F,E,F> | <E,F,E,F,E,F,E,F> | |
| 1381 | + // This way we can see how each of the 8 invocations ends up with the 6th |
| 1382 | + // element of the total sub-group, which is a two-element vector. |
| 1383 | + |
| 1384 | + // Note: the subvector extraction index has to be an i64. Scale it up by |
| 1385 | + // the size of the vector we're extracting: the index is the *element* from |
| 1386 | + // which to extract - it is not implicitly scaled by the vector size. |
| 1387 | + auto *const ExtractIdx = B.getInt64( |
| 1388 | + CIdx->getZExtValue() * DataVecTy->getElementCount().getFixedValue()); |
| 1389 | + VecData = B.CreateExtractVector(Data->getType(), VecData, ExtractIdx, |
| 1390 | + "vec.extract"); |
| 1391 | + } else { |
| 1392 | + // This is as above, but the process of extracting the initial vector is |
| 1393 | + // more complicated - we have to manually extract and insert each element. |
| 1394 | + // It's possible that, for some targets and for some combinations of vector |
| 1395 | + // width and vectorization factor, going through memory would be |
| 1396 | + // faster. |
| 1397 | + Value *ExtractedVec = UndefValue::get(DataVecTy); |
| 1398 | + unsigned const DataNumElts = DataVecTy->getElementCount().getFixedValue(); |
| 1399 | + auto *const BaseIdx = B.CreateMul(VecIdx, B.getInt32(DataNumElts)); |
| 1400 | + for (unsigned i = 0; i < DataNumElts; i++) { |
| 1401 | + auto *const SubIdx = B.CreateAdd(BaseIdx, B.getInt32(i)); |
| 1402 | + auto *const Elt = B.CreateExtractElement(VecData, SubIdx); |
| 1403 | + ExtractedVec = B.CreateInsertElement(ExtractedVec, Elt, B.getInt32(i)); |
| 1404 | + } |
| 1405 | + VecData = ExtractedVec; |
| 1406 | + } |
| 1407 | + |
| 1408 | + // We leave the original shuffle function and divert the vectorized |
| 1409 | + // shuffle through it, giving us a shuffle over the full apparent |
| 1410 | + // sub-group size (vecz * mux). |
| 1411 | + CI->setOperand(0, VecData); |
| 1412 | + CI->setOperand(1, MuxIdx); |
| 1413 | + |
| 1414 | + return CI; |
1280 | 1415 | } |
1281 | 1416 |
|
1282 | 1417 | Value *Packetizer::Impl::packetizeMaskVarying(Instruction *I) { |
|
0 commit comments