Skip to content

Commit

Permalink
Simd CF region pass refactor
Browse files Browse the repository at this point in the history
Add test options:
  "simdcf-rm-loop-mask" - replace selects for use em-mask in loops
  "simdcf-skip-search-preds" - for apply at mostly suitable patterns
Add removeMask function for LoopRegions
  • Loading branch information
igorban-intel authored and igcbot committed Jul 10, 2023
1 parent 2868fb7 commit d733546
Show file tree
Hide file tree
Showing 6 changed files with 1,139 additions and 712 deletions.
41 changes: 33 additions & 8 deletions IGC/VectorCompiler/lib/GenXCodeGen/GenXSimdCFConformance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -753,6 +753,7 @@ void GenXSimdCFConformance::gatherGotoJoinEMVals(bool IncludeIncoming) {
* gatherEMVals : gather the EM values, including phi nodes
*/
void GenXSimdCFConformance::gatherEMVals() {
LLVM_DEBUG(dbgs() << "gatherEMVals: start\n");
// Collect gotos/joins and their defs
gatherGotoJoinEMVals(true);

Expand Down Expand Up @@ -857,6 +858,7 @@ void GenXSimdCFConformance::gatherEMVals() {
* gatherRMVals : gather RM values for each join
*/
void GenXSimdCFConformance::gatherRMVals() {
LLVM_DEBUG(dbgs() << "gatherRMVals: start\n");
for (auto ji = EMVals.begin(), je = EMVals.end(); ji != je; ++ji) {
auto EMVal = *ji;
auto IID = vc::getAnyIntrinsicID(EMVal.getValue());
Expand Down Expand Up @@ -1287,7 +1289,7 @@ void GenXSimdCFConformance::handleCondValue(Value *GotoJoin) {
*/
void GenXSimdCFConformance::splitGotoJoinBlocks() {

LLVM_DEBUG(dbgs() << "Splitting GotoJoin Blocks\n");
LLVM_DEBUG(dbgs() << "splitGotoJoinBlocks: start\n");

for (auto &Elem : GotoJoinEVsMap) {

Expand Down Expand Up @@ -1464,6 +1466,7 @@ bool GenXSimdCFConformance::hoistGotoUser(Instruction *Inst, CallInst *Goto,
* the goto is not conformant.
*/
void GenXSimdCFConformance::moveCodeInGotoBlocks(bool hoistGotoUsers) {
LLVM_DEBUG(dbgs() << "moveCodeInGotoBlocks: start\n");
for (auto gi = EMVals.begin(), ge = EMVals.end(); gi != ge; ++gi) {
auto EMVal = *gi;
auto IID = vc::getAnyIntrinsicID(EMVal.getValue());
Expand Down Expand Up @@ -1565,6 +1568,7 @@ void GenXSimdCFConformance::moveCodeInGotoBlocks(bool hoistGotoUsers) {
* unconditional branch goto for isBranchingGotoJoinBlock to work.
*/
void GenXSimdCFConformance::moveCodeInJoinBlocks() {
LLVM_DEBUG(dbgs() << "moveCodeInJoinBlocks: start\n");
// a. Handle case 3 join blocks.
if (!FG) {
// Early pass: iterate all funcs in the module.
Expand Down Expand Up @@ -1849,6 +1853,7 @@ bool GenXSimdCFConformance::hoistJoin(CallInst *Join) {
* gotos and joins
*/
void GenXSimdCFConformance::ensureConformance() {
LLVM_DEBUG(dbgs() << "ensureConformance: start\n");
// Push all EM values onto the stack for checking. Push the joins last, since
// we want to process those before their corresponding gotos, so that
// GotoJoinMap is set for a goto by the time we process a valid goto.
Expand Down Expand Up @@ -1972,7 +1977,7 @@ Value *GenXSimdCFConformance::getEMProducer(Value *User,
if (It != EMProducers.end()) {
LLVM_DEBUG(if (It->second) dbgs() << "Using previously found value:\n"
<< *It->second << "\n";
else dbgs() << "Using previously found empty-value!\n";);
else dbgs() << "Using previously found empty-value!\n");
return It->second;
}

Expand Down Expand Up @@ -2270,7 +2275,7 @@ bool GenXSimdCFConformance::checkJoin(SimpleValue EMVal) {
<< "#" << RM.getIndex() << "\n";
for (auto i = ConnectedVals.begin(), e = ConnectedVals.end(); i != e;
++i) dbgs()
<< " " << i->getValue()->getName() << "#" << i->getIndex() << "\n";);
<< " " << i->getValue()->getName() << "#" << i->getIndex() << "\n");
if (!Ok) {
LLVM_DEBUG(dbgs() << "checkJoin: illegal RM value in web\n");
return false;
Expand Down Expand Up @@ -2447,16 +2452,21 @@ void GenXSimdCFConformance::pushValues(Value *V) {
* lowered to rdpredregion.
*/
static bool checkAllUsesAreSelectOrWrRegion(Value *V) {
LLVM_DEBUG(dbgs() << "checkAllUsesAreSelectOrWrRegion: start\n");
for (auto ui2 = V->use_begin(); ui2 != V->use_end(); /*empty*/) {
auto User2 = cast<Instruction>(ui2->getUser());
unsigned OpNum = ui2->getOperandNo();
++ui2;
LLVM_DEBUG(dbgs() << "checkAllUsesAreSelectOrWrRegion: for user ";
User2->dump());

if (isa<SelectInst>(User2))
continue;

// Matches uses that can be turned into select.
if (auto BI = dyn_cast<BinaryOperator>(User2)) {
LLVM_DEBUG(dbgs() << "checkAllUsesAreSelectOrWrRegion: binary op\n");

auto Opc = BI->getOpcode();
Constant *AllOne = Constant::getAllOnesValue(V->getType());
Constant *AllNul = Constant::getNullValue(V->getType());
Expand Down Expand Up @@ -2508,7 +2518,9 @@ static bool checkAllUsesAreSelectOrWrRegion(Value *V) {
BI->eraseFromParent();
continue;
}
LLVM_DEBUG(dbgs() << "checkAllUsesAreSelectOrWrRegion: failed\n");
} else if (auto CI = dyn_cast<CastInst>(User2)) {
LLVM_DEBUG(dbgs() << "checkAllUsesAreSelectOrWrRegion: cast inst\n");
// Turn zext/sext to select.
if (CI->getOpcode() == Instruction::CastOps::ZExt ||
CI->getOpcode() == Instruction::CastOps::SExt) {
Expand All @@ -2526,6 +2538,7 @@ static bool checkAllUsesAreSelectOrWrRegion(Value *V) {
CI->eraseFromParent();
continue;
}
LLVM_DEBUG(dbgs() << "checkAllUsesAreSelectOrWrRegion: failed\n");
}

unsigned IID = vc::getAnyIntrinsicID(User2);
Expand All @@ -2537,6 +2550,9 @@ static bool checkAllUsesAreSelectOrWrRegion(Value *V) {
if (vc::isAnyNonTrivialIntrinsic(IID) &&
!cast<CallInst>(User2)->doesNotAccessMemory())
continue;
LLVM_DEBUG(
dbgs() << "checkAllUsesAreSelectOrWrRegion: not found pattern!\n");

return false;
}
return true;
Expand Down Expand Up @@ -2835,8 +2851,8 @@ bool GenXSimdCFConformance::getConnectedVals(
// are select or wrregion.
if (!checkAllUsesAreSelectOrWrRegion(SVI)) {
UsersToLower.push_back(SimpleValue(User, ui->getOperandNo()));
LLVM_DEBUG(dbgs() << "getConnectedVals: all uses push_back " << *User
<< " No=" << ui->getOperandNo() << "\n");
LLVM_DEBUG(dbgs() << "getConnectedVals: UsersToLower push_back "
<< *User << " No=" << ui->getOperandNo() << "\n");
continue;
}
// Shufflevector produces EM for value baled inst, so this is a (almost)
Expand All @@ -2863,6 +2879,8 @@ bool GenXSimdCFConformance::getConnectedVals(
case GenXIntrinsic::genx_simdcf_goto:
LLVM_DEBUG(dbgs() << "getConnectedVals: case genx_simdcf_goto\n");
// use in goto: valid only if arg 0 (EM) or 1 (RM)
LLVM_DEBUG(dbgs() << "with operand no = " << ui->getOperandNo()
<< "\n");
if (ui->getOperandNo() != (Cat == vc::RegCategory::EM ? 0U : 1U))
return false;
// Add corresponding result.
Expand Down Expand Up @@ -2894,8 +2912,8 @@ bool GenXSimdCFConformance::getConnectedVals(
// wrregion.
if (!checkAllUsesAreSelectOrWrRegion(CI)) {
UsersToLower.push_back(SimpleValue(User, ui->getOperandNo()));
LLVM_DEBUG(dbgs() << "getConnectedVals: all uses push_back " << *CI
<< " No=" << ui->getOperandNo() << "\n");
LLVM_DEBUG(dbgs() << "getConnectedVals: UsersToLower push_back "
<< *CI << " No=" << ui->getOperandNo() << "\n");
}
break;
case GenXIntrinsic::genx_wrpredpredregion:
Expand Down Expand Up @@ -2978,8 +2996,15 @@ bool GenXSimdCFConformance::getConnectedVals(
removeFromEMRMVals(Inst);
}
} else {
if (!UsersToLower.empty())
if (!UsersToLower.empty()) {
LLVM_DEBUG(dbgs() << "getConnectedVals: find bad users:\n";
for (auto &BadUser
: UsersToLower) {
dbgs() << " ";
BadUser.dump();
});
return false;
}
}

return true;
Expand Down
Loading

0 comments on commit d733546

Please sign in to comment.