@@ -13,6 +13,7 @@ SPDX-License-Identifier: MIT
1313#include " common/LLVMWarningsPush.hpp"
1414#include < llvm/ADT/DenseMap.h>
1515#include < llvm/ADT/SmallVector.h>
16+ #include < llvm/IR/Dominators.h>
1617#include " common/LLVMWarningsPop.hpp"
1718
1819#define DEBUG_TYPE " igc-wave-shuffle-index-sinking"
@@ -207,19 +208,30 @@ namespace IGC
207208 return numProfitableHoistable;
208209 }
209210
210- void hoist () {
211+ bool hoist (DenseMap<BasicBlock*, SmallVector<Instruction*, 4 >>& MoveToCommonDominatorInstMap, DominatorTree& DT ) {
212+ // If there is no common dominator abort hoisting
213+ BasicBlock* CommonDominator = findCommonDominator ( DT );
214+ if ( !CommonDominator ) return false ;
215+
211216 // Track the new source for all the ShuffleOps
212217 auto * prev = ShuffleOps.front ()->getSrc ();
213218
214219 for ( unsigned idx = 0 ; idx < HoistOrAnchorInstsIdx.size (); idx++ )
215220 {
221+ bool moveToCommonDominator = false ;
216222 if ( HoistOrAnchorInstsIdx[ idx ] )
217223 {
218224 // clone the inst to be hoisted
219225 auto * hoistedInst = InstChains.front ()[ idx ]->clone ();
220226 hoistedInst->setName ( InstChains.front ()[ idx ]->getName () + " _hoisted" );
221227 hoistedInst->insertBefore ( ShuffleOps.front () );
222228
229+ if ( CommonDominator != hoistedInst->getParent () )
230+ {
231+ moveToCommonDominator = true ;
232+ MoveToCommonDominatorInstMap[ CommonDominator ].emplace_back ( hoistedInst );
233+ }
234+
223235 // Replace the correct operand
224236 auto * hoistedOp0 = hoistedInst->getOperand ( 0 );
225237 Instruction* hoistedOpPrev = ( idx == 0 ) ? cast<Instruction>( ShuffleOps.front () ) : InstChains.front ()[ idx - 1 ];
@@ -288,6 +300,12 @@ namespace IGC
288300 InstChains[ i ][ anchorIdx ]->setOperand ( 0 , anchorHoistedInst );
289301 }
290302 }
303+
304+ // If hoisted instruction is moved, it's safe to move anchor as well.
305+ if ( moveToCommonDominator )
306+ {
307+ MoveToCommonDominatorInstMap[ CommonDominator ].emplace_back ( anchorHoistedInst );
308+ }
291309 }
292310 }
293311 }
@@ -345,10 +363,24 @@ namespace IGC
345363 instChain[ i ]->eraseFromParent ();
346364 }
347365 }
366+
367+ return true ;
348368 }
349369
350370 SmallVector<WaveShuffleIndexIntrinsic*> ShuffleOps; // all the WaveShuffleIndex instructions in the group
351371 private:
372+ BasicBlock* findCommonDominator ( DominatorTree& DT )
373+ {
374+ BasicBlock* DomBB = ShuffleOps.front ()->getParent ();
375+ for ( auto & inst : ShuffleOps )
376+ {
377+ BasicBlock* UseBB = inst->getParent ();
378+ DomBB = DT.findNearestCommonDominator ( DomBB, UseBB );
379+ }
380+
381+ return DomBB;
382+ }
383+
352384 SmallVector<SmallVector<BinaryOperator*>> InstChains; // all common instructions shared by the shuffle ops, some can be hoisted
353385 SmallVector<bool > HoistOrAnchorInstsIdx; // Type of each Binary Operator in each InstChain: true - Hoistable/Hoistable past previous Anchors, false - Anchor
354386 }; // ShuffleGroup
@@ -359,13 +391,16 @@ namespace IGC
359391 private:
360392 bool splitWaveShuffleIndexes ();
361393 bool mergeWaveShuffleIndexes ();
394+ bool moveToCommonDominator ();
362395 void gatherShuffleGroups ();
363396 bool sinkShuffleGroups ();
364397 static unsigned compareWaveShuffleIndexes ( WaveShuffleIndexIntrinsic* waveShuffleIndex, WaveShuffleIndexIntrinsic* newWaveShuffleIndex, SmallVector<BinaryOperator*>& InstChain, SmallVector<BinaryOperator*>& newInstChain, SmallVector<bool >& hoistOrAnchor );
365398 static bool isHoistable ( BinaryOperator* inst );
366399 static bool isHoistableOverAnchor ( BinaryOperator* instToHoist, BinaryOperator* anchorInst );
367400 Function& F;
368- DenseMap<std::pair<BasicBlock*, Value*>, SmallVector<ShuffleGroup, 4 >> ShuffleGroupMap;
401+ DominatorTree DT;
402+ DenseMap<BasicBlock*, SmallVector<Instruction*, 4 >> MoveToCommonDominatorInstMap;
403+ DenseMap<Value*, SmallVector<ShuffleGroup, 4 >> ShuffleGroupMap;
369404 DenseSet<WaveShuffleIndexIntrinsic*> Visited;
370405 };
371406
@@ -445,11 +480,30 @@ bool WaveShuffleIndexSinkingImpl::splitWaveShuffleIndexes()
445480 return Changed;
446481}
447482
483+ bool WaveShuffleIndexSinkingImpl::moveToCommonDominator ()
484+ {
485+ // hoisted intruction needs to be moved to common dominator BB.
486+ // If instructions in shuffle group are from different basic blocks
487+ // there is a risk of non-dominating all users.
488+ bool Changed = false ;
489+ for ( auto & bb : MoveToCommonDominatorInstMap )
490+ {
491+ auto instrInsertPtr = ( &*bb.first ->getFirstInsertionPt () );
492+ for ( auto & inst : bb.second )
493+ {
494+ inst->moveBefore ( instrInsertPtr );
495+ Changed = true ;
496+ }
497+ }
498+
499+ return Changed;
500+ }
501+
448502// Merge WaveShuffleIndex instructions that have the same source operand and the same constant lane/channel operand
449503bool WaveShuffleIndexSinkingImpl::mergeWaveShuffleIndexes ()
450504{
451505 // Map from Source to (Map from Lane to list of duplicate instructions)
452- DenseMap<std::pair<BasicBlock*, Value*> , DenseMap<ConstantInt*, SmallVector<WaveShuffleIndexIntrinsic*>>> mergeMap;
506+ DenseMap<Value*, DenseMap<ConstantInt*, SmallVector<WaveShuffleIndexIntrinsic*>>> mergeMap;
453507 for ( auto & BB : F )
454508 {
455509 for ( auto & I : BB )
@@ -458,7 +512,7 @@ bool WaveShuffleIndexSinkingImpl::mergeWaveShuffleIndexes()
458512 {
459513 if ( auto * constantChannel = dyn_cast<ConstantInt>( waveShuffleInst->getChannel () ) )
460514 {
461- mergeMap[ {&BB, waveShuffleInst->getSrc ()} ][ constantChannel ].push_back ( waveShuffleInst );
515+ mergeMap[ waveShuffleInst->getSrc () ][ constantChannel ].push_back ( waveShuffleInst );
462516 }
463517 }
464518 }
@@ -476,12 +530,36 @@ bool WaveShuffleIndexSinkingImpl::mergeWaveShuffleIndexes()
476530 Changed = true ;
477531 auto * mainShuffleIndex = duplicateInsts.front ();
478532
533+ // Find common dominator for main WaveShuffleIndex
534+ bool moveToCommonDominator = false ;
535+ BasicBlock* DomBB = mainShuffleIndex->getParent ();
536+
537+ for ( unsigned i = 1 ; i < duplicateInsts.size (); i++ )
538+ {
539+ BasicBlock* UseBB = duplicateInsts[ i ]->getParent ();
540+ DomBB = DT.findNearestCommonDominator ( DomBB, UseBB );
541+ }
542+
543+ if ( !DomBB )
544+ {
545+ // Do not merge if Common Dominator is not found
546+ Changed = false ;
547+ continue ;
548+ }
549+
550+ moveToCommonDominator = DomBB != mainShuffleIndex->getParent () ? true : false ;
551+
479552 // replace uses of other WaveShuffleIndex with the first one
480553 for ( unsigned i = 1 ; i < duplicateInsts.size (); i++ )
481554 {
482555 duplicateInsts[ i ]->replaceAllUsesWith ( mainShuffleIndex );
483556 duplicateInsts[ i ]->eraseFromParent ();
484557 }
558+
559+ if (moveToCommonDominator)
560+ {
561+ MoveToCommonDominatorInstMap[ DomBB ].emplace_back ( mainShuffleIndex );
562+ }
485563 }
486564 }
487565
@@ -503,13 +581,11 @@ void WaveShuffleIndexSinkingImpl::gatherShuffleGroups()
503581 // Save compute and do not re-process/ create a new ShuffleGroup
504582 continue ;
505583 }
506-
507- std::pair<BasicBlock*, Value*> bbShuffleGroup = { &BB, waveShuffleInst->getSrc () };
508- if ( ShuffleGroupMap.count ( bbShuffleGroup ) )
584+ if ( ShuffleGroupMap.count ( waveShuffleInst->getSrc () ) )
509585 {
510586 // Found existing group(s) with the same source, try to match with one of the groups
511587 bool match = false ;
512- for ( auto & shuffleGroup : ShuffleGroupMap[ bbShuffleGroup ] )
588+ for ( auto & shuffleGroup : ShuffleGroupMap[ waveShuffleInst-> getSrc () ] )
513589 {
514590 if ( shuffleGroup.match ( waveShuffleInst ) )
515591 {
@@ -521,13 +597,13 @@ void WaveShuffleIndexSinkingImpl::gatherShuffleGroups()
521597 // create new ShuffleGroup since no suitable match was found
522598 if ( !match )
523599 {
524- ShuffleGroupMap[ {&BB, waveShuffleInst->getSrc ()} ].emplace_back ( waveShuffleInst );
600+ ShuffleGroupMap[ waveShuffleInst->getSrc () ].emplace_back ( waveShuffleInst );
525601 }
526602 }
527603 else
528604 {
529605 // create new ShuffleGroup for broadcast operations
530- ShuffleGroupMap[ {&BB, waveShuffleInst->getSrc ()} ].emplace_back ( waveShuffleInst );
606+ ShuffleGroupMap[ waveShuffleInst->getSrc () ].emplace_back ( waveShuffleInst );
531607 }
532608 }
533609 }
@@ -546,8 +622,7 @@ bool WaveShuffleIndexSinkingImpl::sinkShuffleGroups()
546622 if ( numProfitableToHoist > 0 )
547623 {
548624 // Pre-process found profitable instructions left to hoist
549- shuffleGroup.hoist ();
550- Changed = true ;
625+ Changed |= shuffleGroup.hoist ( MoveToCommonDominatorInstMap, DT );
551626 }
552627 else
553628 {
@@ -699,6 +774,7 @@ bool WaveShuffleIndexSinkingImpl::isHoistableOverAnchor( BinaryOperator* instToH
699774
700775bool WaveShuffleIndexSinkingImpl::run ()
701776{
777+ DT.recalculate (F);
702778 bool Changed = splitWaveShuffleIndexes ();
703779
704780 unsigned numIters = 0 ;
@@ -718,6 +794,7 @@ bool WaveShuffleIndexSinkingImpl::run()
718794 ShuffleGroupMap.clear ();
719795 }
720796 Changed |= mergeWaveShuffleIndexes ();
797+ Changed |= moveToCommonDominator ();
721798 return Changed;
722799}
723800
0 commit comments