Skip to content

Commit

Permalink
JIT: Make strength reduced IV updates amenable to post-indexed addres…
Browse files Browse the repository at this point in the history
…sing

On arm64 have strength reduction try to insert IV updates after the last
use if that last use is a legal insertion point. This often allows the
backend to use post-indexed addressing to combine the use with the IV
update.
  • Loading branch information
jakobbotsch committed Jul 20, 2024
1 parent b084c08 commit e907ed1
Showing 1 changed file with 156 additions and 9 deletions.
165 changes: 156 additions & 9 deletions src/coreclr/jit/inductionvariableopts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1341,7 +1341,11 @@ class StrengthReductionContext
bool CheckAdvancedCursors(ArrayStack<CursorInfo>* cursors, ScevAddRec** nextIV);
bool StaysWithinManagedObject(ArrayStack<CursorInfo>* cursors, ScevAddRec* addRec);
bool TryReplaceUsesWithNewPrimaryIV(ArrayStack<CursorInfo>* cursors, ScevAddRec* iv);
BasicBlock* FindUpdateInsertionPoint(ArrayStack<CursorInfo>* cursors);
BasicBlock* FindUpdateInsertionPoint(ArrayStack<CursorInfo>* cursors, Statement** afterStmt);
BasicBlock* FindPostUseUpdateInsertionPoint(ArrayStack<CursorInfo>* cursors,
BasicBlock* backEdgeDominator,
Statement** afterStmt);
bool InsertionPointPostDominatesUses(BasicBlock* insertionPoint, ArrayStack<CursorInfo>* cursors);

bool StressProfitability()
{
Expand Down Expand Up @@ -2000,7 +2004,8 @@ bool StrengthReductionContext::TryReplaceUsesWithNewPrimaryIV(ArrayStack<CursorI
return false;
}

BasicBlock* insertionPoint = FindUpdateInsertionPoint(cursors);
Statement* afterStmt;
BasicBlock* insertionPoint = FindUpdateInsertionPoint(cursors, &afterStmt);
if (insertionPoint == nullptr)
{
JITDUMP(" Skipping: could not find a legal insertion point for the new IV update\n");
Expand Down Expand Up @@ -2032,7 +2037,14 @@ bool StrengthReductionContext::TryReplaceUsesWithNewPrimaryIV(ArrayStack<CursorI
m_comp->gtNewOperNode(GT_ADD, iv->Type, m_comp->gtNewLclVarNode(newPrimaryIV, iv->Type), stepValue);
GenTree* stepStore = m_comp->gtNewTempStore(newPrimaryIV, nextValue);
Statement* stepStmt = m_comp->fgNewStmtFromTree(stepStore);
m_comp->fgInsertStmtNearEnd(insertionPoint, stepStmt);
if (afterStmt != nullptr)
{
m_comp->fgInsertStmtAfter(insertionPoint, afterStmt, stepStmt);
}
else
{
m_comp->fgInsertStmtNearEnd(insertionPoint, stepStmt);
}

JITDUMP(" Inserting step statement in " FMT_BB "\n", insertionPoint->bbNum);
DISPSTMT(stepStmt);
Expand Down Expand Up @@ -2084,22 +2096,27 @@ bool StrengthReductionContext::TryReplaceUsesWithNewPrimaryIV(ArrayStack<CursorI
// of a new primary IV introduced by strength reduction.
//
// Parameters:
// cursors - The list of cursors pointing to uses that are being replaced by
// the new IV
// cursors - The list of cursors pointing to uses that are being replaced by
// the new IV
// afterStmt - [out] Statement to insert the update after. Set to nullptr if
// update should be inserted near the end of the block.
//
// Returns:
// Basic block; the insertion point is the end (before a potential
// terminator) of this basic block. May return null if no insertion point
// could be found.
//
BasicBlock* StrengthReductionContext::FindUpdateInsertionPoint(ArrayStack<CursorInfo>* cursors)
BasicBlock* StrengthReductionContext::FindUpdateInsertionPoint(ArrayStack<CursorInfo>* cursors, Statement** afterStmt)
{
*afterStmt = nullptr;

// Find insertion point. It needs to post-dominate all uses we are going to
// replace and it needs to dominate all backedges.
// TODO-CQ: Canonicalizing backedges would make this simpler and work in
// more cases.

BasicBlock* insertionPoint = nullptr;

for (FlowEdge* backEdge : m_loop->BackEdges())
{
if (insertionPoint == nullptr)
Expand All @@ -2112,6 +2129,18 @@ BasicBlock* StrengthReductionContext::FindUpdateInsertionPoint(ArrayStack<Cursor
}
}

#ifdef TARGET_ARM64
// For arm64 we try to place the IV update after a use if possible. This
// sets the backend up for post-indexed addressing mode.
BasicBlock* postUseInsertionPoint = FindPostUseUpdateInsertionPoint(cursors, insertionPoint, afterStmt);
if (postUseInsertionPoint != nullptr)
{
JITDUMP(" Found a legal insertion point after a last use of the IV in " FMT_BB " after " FMT_STMT "\n",
postUseInsertionPoint->bbNum, (*afterStmt)->GetID());
return postUseInsertionPoint;
}
#endif

while ((insertionPoint != nullptr) && m_loop->ContainsBlock(insertionPoint) &&
m_loop->MayExecuteBlockMultipleTimesPerIteration(insertionPoint))
{
Expand All @@ -2123,6 +2152,124 @@ BasicBlock* StrengthReductionContext::FindUpdateInsertionPoint(ArrayStack<Cursor
return nullptr;
}

if (!InsertionPointPostDominatesUses(insertionPoint, cursors))
{
return nullptr;
}

JITDUMP(" Found a legal insertion point in " FMT_BB "\n", insertionPoint->bbNum);
return insertionPoint;
}

//------------------------------------------------------------------------
// FindPostUseUpdateInsertionPoint: Try finding an insertion point for the IV
// update that is right after one of the uses of it.
//
// Parameters:
// cursors - The list of cursors pointing to uses that are being replaced by
// the new IV
// backEdgeDominator - A basic block that dominates all backedges
// afterStmt - [out] Statement to insert the update after, if the
// return value is non-null.
//
// Returns:
// nullptr if no such insertion point could be found. Otherwise returns the
// basic block and statement after which the update can be inserted.
//
BasicBlock* StrengthReductionContext::FindPostUseUpdateInsertionPoint(ArrayStack<CursorInfo>* cursors,
BasicBlock* backEdgeDominator,
Statement** afterStmt)
{
BitVecTraits poTraits = m_loop->GetDfsTree()->PostOrderTraits();

#ifdef DEBUG
// We will be relying on the fact that the cursors are ordered in a useful
// way here: loop locals are visited in post order within each basic block,
// meaning that "cursors" has the last uses first for each basic block.
// Assert that here.

BitVec seenBlocks(BitVecOps::MakeEmpty(&poTraits));
for (int i = 1; i < cursors->Height(); i++)
{
CursorInfo& prevCursor = cursors->BottomRef(i - 1);
CursorInfo& cursor = cursors->BottomRef(i);

if (cursor.Block != prevCursor.Block)
{
assert(BitVecOps::TryAddElemD(&poTraits, seenBlocks, prevCursor.Block->bbPostorderNum));
continue;
}

Statement* curStmt = cursor.Stmt;
while ((curStmt != nullptr) && (curStmt != prevCursor.Stmt))
{
curStmt = curStmt->GetNextStmt();
}

assert(curStmt == prevCursor.Stmt);
}
#endif

BitVec blocksWithUses(BitVecOps::MakeEmpty(&poTraits));
for (int i = 0; i < cursors->Height(); i++)
{
CursorInfo& cursor = cursors->BottomRef(i);
BitVecOps::AddElemD(&poTraits, blocksWithUses, cursor.Block->bbPostorderNum);
}

while ((backEdgeDominator != nullptr) && m_loop->ContainsBlock(backEdgeDominator))
{
if (!BitVecOps::IsMember(&poTraits, blocksWithUses, backEdgeDominator->bbPostorderNum))
{
backEdgeDominator = backEdgeDominator->bbIDom;
continue;
}

if (m_loop->MayExecuteBlockMultipleTimesPerIteration(backEdgeDominator))
{
return nullptr;
}

for (int i = 0; i < cursors->Height(); i++)
{
CursorInfo& cursor = cursors->BottomRef(i);
if (cursor.Block != backEdgeDominator)
{
continue;
}

if (!InsertionPointPostDominatesUses(cursor.Block, cursors))
{
return nullptr;
}

*afterStmt = cursor.Stmt;
return cursor.Block;
}
}

return nullptr;
}

//------------------------------------------------------------------------
// InsertionPointPostDominatesUses: Check if a basic block post-dominates all
// locations specified by the cursors.
//
// Parameters:
// insertionPoint - The insertion point
// cursors - Cursors specifying locations
//
// Returns:
// True if so.
//
// Remarks:
// For cursors inside "insertionPoint", the function expects that the
// insertion point is _after_ the use, except if the use is in a terminator
// statement.
//
bool StrengthReductionContext::InsertionPointPostDominatesUses(BasicBlock* insertionPoint,
ArrayStack<CursorInfo>* cursors)
{
for (int i = 0; i < cursors->Height(); i++)
{
CursorInfo& cursor = cursors->BottomRef(i);
Expand All @@ -2131,19 +2278,19 @@ BasicBlock* StrengthReductionContext::FindUpdateInsertionPoint(ArrayStack<Cursor
{
if (insertionPoint->HasTerminator() && (cursor.Stmt == insertionPoint->lastStmt()))
{
return nullptr;
return false;
}
}
else
{
if (!m_loop->IsPostDominatedOnLoopIteration(cursor.Block, insertionPoint))
{
return nullptr;
return false;
}
}
}

return insertionPoint;
return true;
}

//------------------------------------------------------------------------
Expand Down

0 comments on commit e907ed1

Please sign in to comment.