@@ -535,6 +535,14 @@ void CMSimdCFLower::processFunction(Function *ArgF)
535535 unsigned CMWidth = PredicatedSubroutines[F];
536536 // Find the simd branches.
537537 bool FoundSIMD = findSimdBranches (CMWidth);
538+
539+ // Create shuffle mask for EM adjustment
540+ if (ShuffleMask.empty ()) {
541+ auto I32Ty = Type::getInt32Ty (F->getContext ());
542+ for (unsigned i = 0 ; i != 32 ; ++i)
543+ ShuffleMask.push_back (ConstantInt::get (I32Ty, i));
544+ }
545+
538546 if (CMWidth > 0 || FoundSIMD) {
539547 // Determine which basic blocks need to be predicated.
540548 determinePredicatedBlocks ();
@@ -555,10 +563,13 @@ void CMSimdCFLower::processFunction(Function *ArgF)
555563 lowerSimdCF ();
556564 lowerUnmaskOps ();
557565 }
566+
567+ ShuffleMask.clear ();
558568 SimdBranches.clear ();
559569 PredicatedBlocks.clear ();
560570 JoinPoints.clear ();
561571 RMAddrs.clear ();
572+ OriginalPred.clear ();
562573 AlreadyPredicated.clear ();
563574}
564575
@@ -1214,6 +1225,7 @@ unsigned CMSimdCFLower::deduceNumChannels(Instruction *SI) {
12141225 // If it's not a function call then check for a specific instruction
12151226 unsigned IID = GenXIntrinsic::getGenXIntrinsicID (CI);
12161227 switch (IID) {
1228+ case GenXIntrinsic::genx_gather4_masked_scaled2:
12171229 case GenXIntrinsic::genx_gather4_scaled2: {
12181230 unsigned AddrElems = VCINTR::VectorType::getNumElements (
12191231 cast<VectorType>(CI->getOperand (4 )->getType ()));
@@ -1262,6 +1274,7 @@ void CMSimdCFLower::predicateStore(Instruction *SI, unsigned SimdWidth)
12621274 CallInst *WrRegionToPredicate = nullptr ;
12631275 Use *U = &SI->getOperandUse (0 );
12641276 Use *UseNeedsUpdate = nullptr ;
1277+ Value *ExistingPred = nullptr ;
12651278 for (;;) {
12661279 if (auto BC = dyn_cast<BitCastInst>(V)) {
12671280 U = &BC->getOperandUse (0 );
@@ -1277,6 +1290,15 @@ void CMSimdCFLower::predicateStore(Instruction *SI, unsigned SimdWidth)
12771290 unsigned IID = GenXIntrinsic::getGenXIntrinsicID (WrRegion);
12781291 if (IID != GenXIntrinsic::genx_wrregioni
12791292 && IID != GenXIntrinsic::genx_wrregionf) {
1293+ // genx_gather4_masked_scaled2 is slightly different: it has predicate
1294+ // operand and its users have to be predicated as well since it returns value
1295+ // with size greater of execution size
1296+ if (IID == GenXIntrinsic::genx_gather4_masked_scaled2) {
1297+ assert (AlreadyPredicated.find (WrRegion) != AlreadyPredicated.end ());
1298+ if (OriginalPred.count (WrRegion))
1299+ ExistingPred = OriginalPred[WrRegion];
1300+ break ;
1301+ }
12801302 // Not wrregion. See if it is an intrinsic that has already been
12811303 // predicated; if so do not attempt to predicate the store.
12821304 if (AlreadyPredicated.find (WrRegion) != AlreadyPredicated.end ())
@@ -1361,7 +1383,19 @@ void CMSimdCFLower::predicateStore(Instruction *SI, unsigned SimdWidth)
13611383 Load = CallInst::Create (Fn, Addr, " .simdcfpred.vload" , SI);
13621384 }
13631385 Load->setDebugLoc (SI->getDebugLoc ());
1364- auto EM = loadExecutionMask (SI, SimdWidth, NumChannels);
1386+ Value *EM = loadExecutionMask (SI, SimdWidth);
1387+
1388+ // If there was a predicate already then update it with current EM
1389+ if (ExistingPred) {
1390+ EM = BinaryOperator::Create (
1391+ Instruction::And, ExistingPred, EM,
1392+ ExistingPred->getName () + " .and." + EM->getName (), SI);
1393+ cast<Instruction>(EM)->setDebugLoc (SI->getDebugLoc ());
1394+ }
1395+
1396+ // Replicate mask for each channel if needed
1397+ EM = replicateMask (EM, SI, SimdWidth, NumChannels);
1398+
13651399 auto Select = SelectInst::Create (EM, SI->getOperand (0 ), Load,
13661400 SI->getOperand (0 )->getName () + " .simdcfpred" , SI);
13671401 SI->setOperand (0 , Select);
@@ -1450,16 +1484,26 @@ void CMSimdCFLower::predicateScatterGather(CallInst *CI, unsigned SimdWidth,
14501484{
14511485 Value *OldPred = CI->getArgOperand (PredOperandNum);
14521486 assert (OldPred->getType ()->getScalarType ()->isIntegerTy (1 ));
1453- if (SimdWidth != VCINTR::VectorType::getNumElements (
1454- cast<VectorType>(OldPred->getType ()))) {
1455- DiagnosticInfoSimdCF::emit (CI, " mismatching SIMD width of scatter/gather inside SIMD control flow" );
1456- return ;
1487+ switch (GenXIntrinsic::getGenXIntrinsicID (CI)) {
1488+ case GenXIntrinsic::genx_gather4_masked_scaled2:
1489+ break ;
1490+ default : {
1491+ if (SimdWidth != VCINTR::VectorType::getNumElements (
1492+ cast<VectorType>(OldPred->getType ()))) {
1493+ DiagnosticInfoSimdCF::emit (
1494+ CI,
1495+ " mismatching SIMD width of scatter/gather inside SIMD control flow" );
1496+ return ;
1497+ }
1498+ break ;
1499+ }
14571500 }
14581501 Instruction *NewPred = loadExecutionMask (CI, SimdWidth);
14591502 if (auto C = dyn_cast<Constant>(OldPred))
14601503 if (C->isAllOnesValue ())
14611504 OldPred = nullptr ;
14621505 if (OldPred) {
1506+ OriginalPred[CI] = OldPred;
14631507 auto And = BinaryOperator::Create (Instruction::And, OldPred, NewPred,
14641508 OldPred->getName () + " .and." + NewPred->getName (), CI);
14651509 And->setDebugLoc (CI->getDebugLoc ());
@@ -1496,6 +1540,7 @@ CallInst *CMSimdCFLower::predicateWrRegion(CallInst *WrR, unsigned SimdWidth)
14961540 if (!Pred)
14971541 Pred = EM;
14981542 else {
1543+ OriginalPred[WrR] = Pred;
14991544 auto And = BinaryOperator::Create (Instruction::And, EM, Pred,
15001545 Pred->getName () + " .and." + EM->getName (), WrR);
15011546 And->setDebugLoc (WrR->getDebugLoc ());
@@ -1783,39 +1828,46 @@ CallInst *CMSimdCFLower::isSimdCFAny(Value *V)
17831828 return nullptr ;
17841829}
17851830
1831+ /* **********************************************************************
1832+ * replicateMask : copy mask for provided number of channels using shufflevector
1833+ */
1834+ Value *CMSimdCFLower::replicateMask (Value *EM, Instruction *InsertBefore,
1835+ unsigned SimdWidth, unsigned NumChannels) {
1836+ // No need to replicate the mask for one channel
1837+ if (NumChannels == 1 )
1838+ return EM;
1839+
1840+ SmallVector<Constant *, 128 > ChannelMask{SimdWidth * NumChannels};
1841+ for (unsigned i = 0 ; i < NumChannels; ++i)
1842+ std::copy (ShuffleMask.begin (), ShuffleMask.begin () + SimdWidth,
1843+ ChannelMask.begin () + SimdWidth * i);
1844+ EM = new ShuffleVectorInst (
1845+ EM, UndefValue::get (EM->getType ()), ConstantVector::get (ChannelMask),
1846+ Twine (" ChannelEM" ) + Twine (SimdWidth), InsertBefore);
1847+
1848+ return EM;
1849+ }
1850+
17861851/* **********************************************************************
17871852 * loadExecutionMask : create instruction to load EM
17881853 */
17891854Instruction *CMSimdCFLower::loadExecutionMask (Instruction *InsertBefore,
1790- unsigned SimdWidth, unsigned NumChannels)
1791- {
1855+ unsigned SimdWidth) {
17921856 Instruction *EM =
17931857 new LoadInst (EMVar->getType ()->getPointerElementType (), EMVar,
17941858 EMVar->getName (), false /* isVolatile */ , InsertBefore);
1795- EM-> setDebugLoc (InsertBefore-> getDebugLoc ());
1859+
17961860 // If the simd width is not MAX_SIMD_CF_WIDTH, extract the part of EM we want.
1797- if (NumChannels == 1 && SimdWidth == MAX_SIMD_CF_WIDTH)
1861+ if (SimdWidth == MAX_SIMD_CF_WIDTH)
17981862 return EM;
1799- if (ShuffleMask.empty ()) {
1800- auto I32Ty = Type::getInt32Ty (F->getContext ());
1801- for (unsigned i = 0 ; i != 32 ; ++i)
1802- ShuffleMask.push_back (ConstantInt::get (I32Ty, i));
1803- }
1804- if (NumChannels == 1 ) {
1805- ArrayRef<Constant *> Mask = ShuffleMask;
1806- EM = new ShuffleVectorInst (EM, UndefValue::get (EM->getType ()),
1807- ConstantVector::get (Mask.take_front (SimdWidth)),
1808- Twine (" EM" ) + Twine (SimdWidth), InsertBefore);
1809- } else {
1810- SmallVector<Constant *, 128 > ChannelMask{SimdWidth * NumChannels};
1811- for (unsigned i = 0 ; i < NumChannels; ++i)
1812- std::copy (ShuffleMask.begin (), ShuffleMask.begin () + SimdWidth,
1813- ChannelMask.begin () + SimdWidth * i);
1814- EM = new ShuffleVectorInst (
1815- EM, UndefValue::get (EM->getType ()), ConstantVector::get (ChannelMask),
1816- Twine (" ChannelEM" ) + Twine (SimdWidth), InsertBefore);
1817- }
1863+
1864+ ArrayRef<Constant *> Mask = ShuffleMask;
1865+ EM = new ShuffleVectorInst (EM, UndefValue::get (EM->getType ()),
1866+ ConstantVector::get (Mask.take_front (SimdWidth)),
1867+ Twine (" EM" ) + Twine (SimdWidth), InsertBefore);
1868+
18181869 EM->setDebugLoc (InsertBefore->getDebugLoc ());
1870+
18191871 return EM;
18201872}
18211873
0 commit comments