Skip to content

Commit e73debc

Browse files
committed
[vecz] Packetize sub-group shuffles with uniform indices
This extends fixed-width vectorization capabilities to `__mux_sub_group_shuffle` builtins, but only those with uniform indices (where the shuffle index is the same for all invocations in the sub-group). This accounts for the majority of those tested by the SYCL-CTS. Support for varying indices will come down the line, once the other shuffles are covered under similar conditions. The existing sub-group LIT tests have been split by operation, as they are expected to grow significantly to cover all of the different conditions we can vectorize under.
1 parent e77143e commit e73debc

File tree

6 files changed

+445
-91
lines changed

6 files changed

+445
-91
lines changed

llvm/lib/SYCLNativeCPUUtils/compiler_passes/vecz/source/transform/packetizer.cpp

Lines changed: 147 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@
1717
#include "transform/packetizer.h"
1818

1919
#include <compiler/utils/builtin_info.h>
20+
#include <compiler/utils/group_collective_helpers.h>
2021
#include <compiler/utils/mangling.h>
2122
#include <llvm/ADT/DepthFirstIterator.h>
2223
#include <llvm/ADT/SmallPtrSet.h>
2324
#include <llvm/ADT/Statistic.h>
2425
#include <llvm/ADT/Twine.h>
2526
#include <llvm/Analysis/LoopInfo.h>
2627
#include <llvm/Analysis/VectorUtils.h>
28+
#include <llvm/IR/Constants.h>
2729
#include <llvm/IR/DIBuilder.h>
2830
#include <llvm/IR/DebugInfoMetadata.h>
2931
#include <llvm/IR/IRBuilder.h>
@@ -202,13 +204,24 @@ class Packetizer::Impl : public Packetizer {
202204
///
203205
/// @return Packetized instruction.
204206
Value *packetizeGroupBroadcast(Instruction *I);
205-
/// @brief Returns true if the instruction is a subgroup shuffle.
207+
/// @brief Returns true if the instruction is any subgroup shuffle.
206208
///
207209
/// @param[in] I Instruction to query.
208210
///
209-
/// @return True if the instruction is a call to a mux subgroup shuffle
211+
/// @return The group collective data if the instruction is a call to any of
212+
/// the mux subgroup shuffle builtins; std::nullopt otherwise.
213+
std::optional<compiler::utils::GroupCollective> isSubgroupShuffleLike(
214+
Instruction *I);
215+
/// @brief Packetize a sub-group shuffle builtin
216+
///
217+
/// Note - not any shuffle-like operation, but specifically the 'shuffle'
210218
/// builtin.
211-
bool isSubgroupShuffle(Instruction *I);
219+
///
220+
/// @param[in] Ins Instruction to packetize.
221+
///
222+
/// @return Packetized instructions.
223+
Value *packetizeSubgroupShuffle(Instruction *Ins);
224+
212225
/// @brief Packetize PHI node.
213226
///
214227
/// @param[in] Phi PHI Node to packetize.
@@ -861,12 +874,6 @@ Packetizer::Result Packetizer::Impl::packetize(Value *V) {
861874

862875
auto *const Ins = cast<Instruction>(V);
863876

864-
// FIXME: Add support for vectorizing sub-group shuffles
865-
if (isSubgroupShuffle(Ins)) {
866-
emitVeczRemarkMissed(&F, Ins, "Could not packetize sub-group shuffle");
867-
return Packetizer::Result(*this);
868-
}
869-
870877
if (auto *const Branch = dyn_cast<BranchInst>(Ins)) {
871878
if (Branch->isConditional()) {
872879
// varying reductions need to be packetized
@@ -918,6 +925,19 @@ Packetizer::Result Packetizer::Impl::packetize(Value *V) {
918925
return broadcast(brdcast);
919926
}
920927

928+
if (auto shuffle = isSubgroupShuffleLike(Ins)) {
929+
if (shuffle->Op == compiler::utils::GroupCollective::OpKind::Shuffle) {
930+
if (auto *s = packetizeSubgroupShuffle(Ins)) {
931+
return broadcast(s);
932+
}
933+
}
934+
// We can't packetize all sub-group shuffle-like operations, but we also
935+
// can't vectorize or instantiate them - so provide a diagnostic saying as
936+
// much.
937+
emitVeczRemarkMissed(&F, Ins, "Could not packetize sub-group shuffle");
938+
return Packetizer::Result(*this);
939+
}
940+
921941
// Check if we should broadcast the instruction.
922942
// Broadcast uniform instructions, unless we want to packetize uniform
923943
// instructions as well. We can assume that isMaskVarying is false at this
@@ -1265,18 +1285,133 @@ Value *Packetizer::Impl::packetizeGroupBroadcast(Instruction *I) {
12651285
return CI;
12661286
}
12671287

1268-
bool Packetizer::Impl::isSubgroupShuffle(Instruction *I) {
1288+
std::optional<compiler::utils::GroupCollective>
1289+
Packetizer::Impl::isSubgroupShuffleLike(Instruction *I) {
12691290
auto *const CI = dyn_cast<CallInst>(I);
12701291
if (!CI || !CI->getCalledFunction()) {
1271-
return false;
1292+
return std::nullopt;
12721293
}
12731294
compiler::utils::BuiltinInfo &BI = Ctx.builtins();
12741295
Function *callee = CI->getCalledFunction();
12751296

12761297
auto const Builtin = BI.analyzeBuiltin(*callee);
12771298
auto const Info = BI.isMuxGroupCollective(Builtin.ID);
12781299

1279-
return Info && Info->isSubGroupScope() && Info->isShuffleLike();
1300+
if (Info && Info->isSubGroupScope() && Info->isShuffleLike()) {
1301+
return Info;
1302+
}
1303+
1304+
return std::nullopt;
1305+
}
1306+
1307+
Value *Packetizer::Impl::packetizeSubgroupShuffle(Instruction *I) {
1308+
auto *const CI = cast<CallInst>(I);
1309+
1310+
// We don't support scalable vectorization of sub-group shuffles.
1311+
if (SimdWidth.isScalable()) {
1312+
return nullptr;
1313+
}
1314+
1315+
auto *const Data = CI->getArgOperand(0);
1316+
auto *const Idx = CI->getArgOperand(1);
1317+
1318+
auto PackData = packetize(Data);
1319+
if (!PackData) {
1320+
return nullptr;
1321+
}
1322+
1323+
// If the data operand happened to be a broadcast value already, we can use
1324+
// it directly.
1325+
if (PackData.info->numInstances == 0) {
1326+
IC.deleteInstructionLater(CI);
1327+
CI->replaceAllUsesWith(Data);
1328+
return Data;
1329+
}
1330+
1331+
// We can't packetize varying shuffle indices yet.
1332+
if (UVR.isVarying(Idx)) {
1333+
return nullptr;
1334+
}
1335+
1336+
IRBuilder<> B(CI);
1337+
1338+
// We need to sanitize the input index so that it stays within the range of
1339+
// one vectorized group.
1340+
unsigned const VF = SimdWidth.getFixedValue();
1341+
auto *const VecIdxFactor = ConstantInt::get(Idx->getType(), VF);
1342+
// This index is the element of the vector-group which holds the desired
1343+
// data, per mux sub-group.
1344+
// <x, y>, <z, w>: idx 1 -> vector element 1, idx 2 -> vector element 0.
1345+
auto *const VecIdx = B.CreateURem(Idx, VecIdxFactor);
1346+
// This index is the mux sub-group in which the desired data resides.
1347+
// <x, y>, <z, w>: idx 1 -> mux sub-group 0, idx 3 -> mux sub-group 1.
1348+
auto *const MuxIdx = B.CreateUDiv(Idx, VecIdxFactor);
1349+
1350+
Value *VecData = PackData.getAsValue();
1351+
1352+
// Note: in each illustrative example, imagine two invocations across a
1353+
// single mux sub-groups, each being vectorized by 4; in other words, 8
1354+
// 'original' invocations to a sub-group, running in two vectorized
1355+
// invocations.
1356+
if (auto *const DataVecTy = dyn_cast<VectorType>(Data->getType());
1357+
!DataVecTy) {
1358+
// The vectorized shuffle is producing a scalar (assuming uniform indices,
1359+
// see above). Imagine i=6 (6 % 4 = 2 and 6 / 4 = 1):
1360+
// | shuffle(X, 6) | shuffle(A, 6) |
1361+
// VF=4 |-----------------|-----------------|
1362+
// | s(<X,Y,Z,W>, 2) | s(<A,B,C,D>, 2) |
1363+
// elt 2 | Z | C |
1364+
// shuff | shuffle(Z, 1) | shuffle(C, 1) |
1365+
// | C | C |
1366+
// bcast | <C,C,C,C> | <C,C,C,C> |
1367+
// This way we can see how each of the 8 invocations end up with the 6th
1368+
// element of the total sub-group.
1369+
VecData = B.CreateExtractElement(VecData, VecIdx, "vec.extract");
1370+
} else if (auto *const CIdx = dyn_cast<ConstantInt>(VecIdx)) {
1371+
// The shuffle produces a vector, and we have a constant shuffle index - we
1372+
// can extract a subvector easily.
1373+
// Imagine i=6 (6 % 4 = 2 and 6 / 4 = 1):
1374+
// | shuffle(<X,Y>, 6) | shuffle(<A,B>, 6) |
1375+
// VF=4 |-------------------------|-------------------------|
1376+
// | s(<X,Y,Z,W,P,Q,-,->, 2) | s(<A,B,C,D,E,F,-,->, 2) |
1377+
// vec 2 | <P,Q> | <E,F> |
1378+
// shuff | shuffle(<P,Q>, 1) | shuffle(<E,F>, 1) |
1379+
// | <E,F> | <E,F> |
1380+
// bcast | <E,F,E,F,E,F,E,F> | <E,F,E,F,E,F,E,F> |
1381+
// This way we can see how each of the 8 invocations end up with the 6th
1382+
// element of the total sub-group, which is a two-element vector.
1383+
1384+
// Note: the subvector vector index type has to be i64. Scale it up by the
1385+
// size of the vector we're extracting: the index is the *element* from
1386+
// which to extract - it is not implicitly scaled by the vector size.
1387+
auto *const ExtractIdx = B.getInt64(
1388+
CIdx->getZExtValue() * DataVecTy->getElementCount().getFixedValue());
1389+
VecData = B.CreateExtractVector(Data->getType(), VecData, ExtractIdx,
1390+
"vec.extract");
1391+
} else {
1392+
// This is as above, but the process of extracting the initial vector is
1393+
// more complicated - we have to manually extract and insert each element.
1394+
// It's possible that for some targets and for some combinations of vector
1395+
// width and vectorization factor, that going through memory would be
1396+
// faster.
1397+
Value *ExtractedVec = UndefValue::get(DataVecTy);
1398+
unsigned const DataNumElts = DataVecTy->getElementCount().getFixedValue();
1399+
auto *const BaseIdx = B.CreateMul(VecIdx, B.getInt32(DataNumElts));
1400+
for (unsigned i = 0; i < DataNumElts; i++) {
1401+
auto *const SubIdx = B.CreateAdd(BaseIdx, B.getInt32(i));
1402+
auto *const Elt = B.CreateExtractElement(VecData, SubIdx);
1403+
ExtractedVec = B.CreateInsertElement(ExtractedVec, Elt, B.getInt32(i));
1404+
}
1405+
VecData = ExtractedVec;
1406+
}
1407+
1408+
// We leave the original shuffle function and divert the vectorized
1409+
// shuffle through it, giving us a shuffle over the full apparent
1410+
// sub-group size (vecz * mux).
1411+
CI->setOperand(0, VecData);
1412+
CI->setOperand(1, MuxIdx);
1413+
1414+
return CI;
12801415
}
12811416

12821417
Value *Packetizer::Impl::packetizeMaskVarying(Instruction *I) {

0 commit comments

Comments
 (0)