Skip to content

Commit a459d3d

Browse files
MiloszSkobejko authored and igcbot committed
Fix WaveShuffleIndexSinking Pass to handle control flow
The dominator tree is used to check whether the newly created instructions for hoisted values are placed in a dominating basic block; if they are not, they are moved to the nearest common dominator. The same mechanism is applied to the merging function: if the "main" WaveShuffleIndex call is not in a dominating basic block, it is moved to one.
1 parent 31fe1d1 commit a459d3d

File tree

4 files changed

+173
-58
lines changed

4 files changed

+173
-58
lines changed

IGC/Compiler/Optimizer/WaveShuffleIndexSinking.cpp

Lines changed: 89 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ SPDX-License-Identifier: MIT
1313
#include "common/LLVMWarningsPush.hpp"
1414
#include <llvm/ADT/DenseMap.h>
1515
#include <llvm/ADT/SmallVector.h>
16+
#include <llvm/IR/Dominators.h>
1617
#include "common/LLVMWarningsPop.hpp"
1718

1819
#define DEBUG_TYPE "igc-wave-shuffle-index-sinking"
@@ -207,19 +208,30 @@ namespace IGC
207208
return numProfitableHoistable;
208209
}
209210

210-
void hoist() {
211+
bool hoist(DenseMap<BasicBlock*, SmallVector<Instruction*, 4>>& MoveToCommonDominatorInstMap, DominatorTree& DT ) {
212+
// If there is no common dominator abort hoisting
213+
BasicBlock* CommonDominator = findCommonDominator( DT );
214+
if ( !CommonDominator ) return false;
215+
211216
// Track the new source for all the ShuffleOps
212217
auto* prev = ShuffleOps.front()->getSrc();
213218

214219
for( unsigned idx = 0; idx < HoistOrAnchorInstsIdx.size(); idx++ )
215220
{
221+
bool moveToCommonDominator = false;
216222
if( HoistOrAnchorInstsIdx[ idx ] )
217223
{
218224
// clone the inst to be hoisted
219225
auto* hoistedInst = InstChains.front()[ idx ]->clone();
220226
hoistedInst->setName( InstChains.front()[ idx ]->getName() + "_hoisted" );
221227
hoistedInst->insertBefore( ShuffleOps.front() );
222228

229+
if ( CommonDominator != hoistedInst->getParent() )
230+
{
231+
moveToCommonDominator = true;
232+
MoveToCommonDominatorInstMap[ CommonDominator ].emplace_back( hoistedInst );
233+
}
234+
223235
// Replace the correct operand
224236
auto* hoistedOp0 = hoistedInst->getOperand( 0 );
225237
Instruction* hoistedOpPrev = ( idx == 0 ) ? cast<Instruction>( ShuffleOps.front() ) : InstChains.front()[ idx - 1 ];
@@ -288,6 +300,12 @@ namespace IGC
288300
InstChains[ i ][ anchorIdx ]->setOperand( 0, anchorHoistedInst );
289301
}
290302
}
303+
304+
// If hoisted instruction is moved, it's safe to move anchor as well.
305+
if ( moveToCommonDominator )
306+
{
307+
MoveToCommonDominatorInstMap[ CommonDominator ].emplace_back( anchorHoistedInst );
308+
}
291309
}
292310
}
293311
}
@@ -345,10 +363,24 @@ namespace IGC
345363
instChain[ i ]->eraseFromParent();
346364
}
347365
}
366+
367+
return true;
348368
}
349369

350370
SmallVector<WaveShuffleIndexIntrinsic*> ShuffleOps; // all the WaveShuffleIndex instructions in the group
351371
private:
372+
// Computes a single basic block that dominates the parent block of every
// WaveShuffleIndex instruction stored in ShuffleOps.
//
// The result is built by repeatedly calling DT.findNearestCommonDominator,
// starting from the parent block of the first shuffle op and folding in the
// parent block of each op in turn.
//
// DT: dominator tree used for the queries.
// Returns the accumulated block; hoist() treats a null result as
// "no common dominator" and aborts hoisting.
BasicBlock* findCommonDominator( DominatorTree& DT )
373+
{
374+
// Seed with the block of the first shuffle op.
BasicBlock* DomBB = ShuffleOps.front()->getParent();
375+
// The loop also visits the first op, so the first query pairs that block with itself.
for ( auto& inst : ShuffleOps )
376+
{
377+
BasicBlock* UseBB = inst->getParent();
378+
// NOTE(review): if this ever yields null, the next iteration passes that
// null back in as the first argument - confirm the LLVM version in use tolerates it.
DomBB = DT.findNearestCommonDominator( DomBB, UseBB );
379+
}
380+
381+
return DomBB;
382+
}
383+
352384
SmallVector<SmallVector<BinaryOperator*>> InstChains; // all common instructions shared by the shuffle ops, some can be hoisted
353385
SmallVector<bool> HoistOrAnchorInstsIdx; // Type of each Binary Operator in each InstChain: true - Hoistable/Hoistable past previous Anchors, false - Anchor
354386
}; //ShuffleGroup
@@ -359,13 +391,16 @@ namespace IGC
359391
private:
360392
bool splitWaveShuffleIndexes();
361393
bool mergeWaveShuffleIndexes();
394+
bool moveToCommonDominator();
362395
void gatherShuffleGroups();
363396
bool sinkShuffleGroups();
364397
static unsigned compareWaveShuffleIndexes( WaveShuffleIndexIntrinsic* waveShuffleIndex, WaveShuffleIndexIntrinsic* newWaveShuffleIndex, SmallVector<BinaryOperator*>& InstChain, SmallVector<BinaryOperator*>& newInstChain, SmallVector<bool>& hoistOrAnchor );
365398
static bool isHoistable( BinaryOperator* inst );
366399
static bool isHoistableOverAnchor( BinaryOperator* instToHoist, BinaryOperator* anchorInst );
367400
Function& F;
368-
DenseMap<std::pair<BasicBlock*, Value*>, SmallVector<ShuffleGroup, 4>> ShuffleGroupMap;
401+
DominatorTree DT;
402+
DenseMap<BasicBlock*, SmallVector<Instruction*, 4>> MoveToCommonDominatorInstMap;
403+
DenseMap<Value*, SmallVector<ShuffleGroup, 4>> ShuffleGroupMap;
369404
DenseSet<WaveShuffleIndexIntrinsic*> Visited;
370405
};
371406

@@ -445,11 +480,30 @@ bool WaveShuffleIndexSinkingImpl::splitWaveShuffleIndexes()
445480
return Changed;
446481
}
447482

483+
// Relocates every instruction recorded in MoveToCommonDominatorInstMap into the
// basic block it was registered under (the map key).
//
// The map is filled by ShuffleGroup::hoist() and by mergeWaveShuffleIndexes();
// run() calls this function after mergeWaveShuffleIndexes().
//
// Returns true if at least one instruction was moved, false otherwise.
bool WaveShuffleIndexSinkingImpl::moveToCommonDominator()
484+
{
485+
// Hoisted instructions need to be moved to the common dominator BB.
486+
// If the instructions of a shuffle group come from different basic blocks,
487+
// an instruction left in its original block might not dominate all of its users.
488+
bool Changed = false;
489+
for ( auto& bb : MoveToCommonDominatorInstMap )
490+
{
491+
// The insertion point is taken once per target block, before any instruction is moved.
auto instrInsertPtr = ( &*bb.first->getFirstInsertionPt() );
492+
for ( auto& inst : bb.second )
493+
{
494+
// Every instruction is placed directly before the same insertion point.
// NOTE(review): presumably this keeps the moved instructions in list order,
// provided the insertion point is not itself one of them - confirm.
inst->moveBefore( instrInsertPtr );
495+
Changed = true;
496+
}
497+
}
498+
499+
// NOTE(review): the map is not cleared here.
return Changed;
500+
}
501+
448502
// Merge WaveShuffleIndex instructions that have the same source operand and the same constant lane/channel operand
449503
bool WaveShuffleIndexSinkingImpl::mergeWaveShuffleIndexes()
450504
{
451505
// Map from Source to (Map from Lane to list of duplicate instructions)
452-
DenseMap<std::pair<BasicBlock*, Value*>, DenseMap<ConstantInt*, SmallVector<WaveShuffleIndexIntrinsic*>>> mergeMap;
506+
DenseMap<Value*, DenseMap<ConstantInt*, SmallVector<WaveShuffleIndexIntrinsic*>>> mergeMap;
453507
for( auto& BB : F )
454508
{
455509
for( auto& I : BB )
@@ -458,7 +512,7 @@ bool WaveShuffleIndexSinkingImpl::mergeWaveShuffleIndexes()
458512
{
459513
if( auto* constantChannel = dyn_cast<ConstantInt>( waveShuffleInst->getChannel() ) )
460514
{
461-
mergeMap[ {&BB, waveShuffleInst->getSrc()} ][ constantChannel ].push_back( waveShuffleInst );
515+
mergeMap[ waveShuffleInst->getSrc() ][ constantChannel ].push_back( waveShuffleInst );
462516
}
463517
}
464518
}
@@ -476,12 +530,36 @@ bool WaveShuffleIndexSinkingImpl::mergeWaveShuffleIndexes()
476530
Changed = true;
477531
auto* mainShuffleIndex = duplicateInsts.front();
478532

533+
// Find common dominator for main WaveShuffleIndex
534+
bool moveToCommonDominator = false;
535+
BasicBlock* DomBB = mainShuffleIndex->getParent();
536+
537+
for ( unsigned i = 1; i < duplicateInsts.size(); i++ )
538+
{
539+
BasicBlock* UseBB = duplicateInsts[ i ]->getParent();
540+
DomBB = DT.findNearestCommonDominator( DomBB, UseBB );
541+
}
542+
543+
if ( !DomBB )
544+
{
545+
// Do not merge if Common Dominator is not found
546+
Changed = false;
547+
continue;
548+
}
549+
550+
moveToCommonDominator = DomBB != mainShuffleIndex->getParent() ? true : false;
551+
479552
// replace uses of other WaveShuffleIndex with the first one
480553
for( unsigned i = 1; i < duplicateInsts.size(); i++ )
481554
{
482555
duplicateInsts[ i ]->replaceAllUsesWith( mainShuffleIndex );
483556
duplicateInsts[ i ]->eraseFromParent();
484557
}
558+
559+
if (moveToCommonDominator)
560+
{
561+
MoveToCommonDominatorInstMap[ DomBB ].emplace_back( mainShuffleIndex );
562+
}
485563
}
486564
}
487565

@@ -503,13 +581,11 @@ void WaveShuffleIndexSinkingImpl::gatherShuffleGroups()
503581
// Save compute and do not re-process/ create a new ShuffleGroup
504582
continue;
505583
}
506-
507-
std::pair<BasicBlock*, Value*> bbShuffleGroup = { &BB, waveShuffleInst->getSrc() };
508-
if ( ShuffleGroupMap.count( bbShuffleGroup ) )
584+
if( ShuffleGroupMap.count( waveShuffleInst->getSrc() ) )
509585
{
510586
// Found existing group(s) with the same source, try to match with one of the groups
511587
bool match = false;
512-
for (auto& shuffleGroup : ShuffleGroupMap[ bbShuffleGroup ] )
588+
for( auto& shuffleGroup : ShuffleGroupMap[ waveShuffleInst->getSrc() ] )
513589
{
514590
if( shuffleGroup.match( waveShuffleInst ) )
515591
{
@@ -521,13 +597,13 @@ void WaveShuffleIndexSinkingImpl::gatherShuffleGroups()
521597
// create new ShuffleGroup since no suitable match was found
522598
if( !match )
523599
{
524-
ShuffleGroupMap[ {&BB, waveShuffleInst->getSrc()} ].emplace_back( waveShuffleInst );
600+
ShuffleGroupMap[ waveShuffleInst->getSrc() ].emplace_back( waveShuffleInst );
525601
}
526602
}
527603
else
528604
{
529605
// create new ShuffleGroup for broadcast operations
530-
ShuffleGroupMap[ {&BB, waveShuffleInst->getSrc()} ].emplace_back( waveShuffleInst );
606+
ShuffleGroupMap[ waveShuffleInst->getSrc() ].emplace_back( waveShuffleInst );
531607
}
532608
}
533609
}
@@ -546,8 +622,7 @@ bool WaveShuffleIndexSinkingImpl::sinkShuffleGroups()
546622
if( numProfitableToHoist > 0 )
547623
{
548624
// Pre-process found profitable instructions left to hoist
549-
shuffleGroup.hoist();
550-
Changed = true;
625+
Changed |= shuffleGroup.hoist( MoveToCommonDominatorInstMap, DT );
551626
}
552627
else
553628
{
@@ -699,6 +774,7 @@ bool WaveShuffleIndexSinkingImpl::isHoistableOverAnchor( BinaryOperator* instToH
699774

700775
bool WaveShuffleIndexSinkingImpl::run()
701776
{
777+
DT.recalculate(F);
702778
bool Changed = splitWaveShuffleIndexes();
703779

704780
unsigned numIters = 0;
@@ -718,6 +794,7 @@ bool WaveShuffleIndexSinkingImpl::run()
718794
ShuffleGroupMap.clear();
719795
}
720796
Changed |= mergeWaveShuffleIndexes();
797+
Changed |= moveToCommonDominator();
721798
return Changed;
722799
}
723800

IGC/Compiler/tests/WaveShuffleIndexSinking/constant-src-different-bb.ll

Lines changed: 0 additions & 46 deletions
This file was deleted.
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
;=========================== begin_copyright_notice ============================
2+
;
3+
; Copyright (C) 2025 Intel Corporation
4+
;
5+
; SPDX-License-Identifier: MIT
6+
;
7+
;============================ end_copyright_notice =============================
8+
; RUN: igc_opt --typed-pointers -igc-wave-shuffle-index-sinking -S < %s | FileCheck %s
9+
; ------------------------------------------------
10+
; WaveShuffleIndexSinking
11+
;
12+
; Verifies that the pass correctly moves the hoisted instructions and the merged WaveShuffleIndex
13+
; to the common dominator basic block in the correct order. In this case it is bb1.
14+
15+
define void @test(i32* %dst0, i32* %dst1, i32 %a, i32 %b) {
16+
entry:
17+
br label %bb1
18+
19+
bb0:
20+
; CHECK-LABEL: bb0:
21+
; CHECK-NOT: [[WS:%.*]] = call i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32 [[TMP:.*]], i32 0, i32 0)
22+
%ws0 = call i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32 %a, i32 0, i32 0)
23+
%add0 = add i32 %b, %ws0
24+
%shl0 = shl i32 %add0, 2
25+
store i32 %shl0, i32* %dst0
26+
br label %exit
27+
28+
bb1:
29+
; CHECK-LABEL: bb1:
30+
; CHECK-NEXT: [[HOISTED:%.*]] = shl i32 %a, 2
31+
; CHECK-NEXT: [[ANCHOR_HOISTED:%.*]] = shl i32 %b, 2
32+
; CHECK-NEXT: [[WS0:%.*]] = call i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32 [[HOISTED]], i32 0, i32 0)
33+
%ws1 = call i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32 %a, i32 0, i32 0)
34+
%add1 = add i32 %ws1, %b
35+
%shl1 = shl i32 %add1, 2
36+
store i32 %shl1, i32* %dst1
37+
br label %bb0
38+
39+
exit:
40+
ret void
41+
}
42+
43+
; Function Attrs: convergent nounwind readnone
44+
declare i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32, i32, i32) #0
45+
46+
attributes #0 = { convergent nounwind readnone }
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
;=========================== begin_copyright_notice ============================
2+
;
3+
; Copyright (C) 2025 Intel Corporation
4+
;
5+
; SPDX-License-Identifier: MIT
6+
;
7+
;============================ end_copyright_notice =============================
8+
; RUN: igc_opt --typed-pointers -igc-wave-shuffle-index-sinking -S < %s | FileCheck %s
9+
; ------------------------------------------------
10+
; WaveShuffleIndexSinking
11+
;
12+
; Verifies that the pass correctly moves the "main" WaveShuffleIndex call to the
13+
; common dominator basic block, which in this case is bb1.
14+
15+
define void @test(i32* %dst0, i32* %dst1, i32 %a) {
16+
entry:
17+
br label %bb1
18+
19+
bb0:
20+
%ws0 = call i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32 %a, i32 0, i32 0)
21+
store i32 %ws0, i32* %dst0
22+
br label %exit
23+
24+
bb1:
25+
; CHECK-LABEL: bb1:
26+
; CHECK-NEXT: [[WS0:%.*]] = call i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32 %a, i32 0, i32 0)
27+
%ws1 = call i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32 %a, i32 0, i32 0)
28+
store i32 %ws1, i32* %dst1
29+
br label %bb0
30+
31+
exit:
32+
ret void
33+
}
34+
35+
; Function Attrs: convergent nounwind readnone
36+
declare i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32, i32, i32) #0
37+
38+
attributes #0 = { convergent nounwind readnone }

0 commit comments

Comments
 (0)