Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle overlapped groups of bounds checks #112660

Merged
merged 4 commits into from
Feb 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 77 additions & 28 deletions src/coreclr/jit/rangecheckcloning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,18 @@
// Arguments:
// comp - The compiler instance
// statement - The statement containing the bounds check
// bndChkNode - The bounds check node
// bndChkParentNode - The parent node of the bounds check node (either null or COMMA)
// statementIdx - The index of the statement in the block
// bndChk - The bounds check node (its use edge)
//
// Return Value:
// true if the initialization was successful, false otherwise.
//
bool BoundsCheckInfo::Initialize(const Compiler* comp, Statement* statement, GenTree** bndChk)
bool BoundsCheckInfo::Initialize(const Compiler* comp, Statement* statement, int statementIdx, GenTree** bndChk)
{
assert((bndChk != nullptr) && ((*bndChk) != nullptr));

stmt = statement;
stmtIdx = statementIdx;
bndChkUse = bndChk;
idxVN = comp->vnStore->VNConservativeNormalValue(BndChk()->GetIndex()->gtVNPair);
lenVN = comp->vnStore->VNConservativeNormalValue(BndChk()->GetArrayLength()->gtVNPair);
Expand Down Expand Up @@ -98,10 +99,9 @@ bool BoundsCheckInfo::Initialize(const Compiler* comp, Statement* statement, Gen
// RemoveBoundsChk - Remove the given bounds check from the statement and the block.
//
// Arguments:
// comp - compiler instance
// check - bounds check node to remove
// comma - check's parent node (either null or COMMA)
// stmt - statement containing the bounds check
// comp - compiler instance
// treeUse - the bounds check node to remove (its use edge)
// stmt - the statement containing the bounds check
//
static void RemoveBoundsChk(Compiler* comp, GenTree** treeUse, Statement* stmt)
{
Expand Down Expand Up @@ -155,18 +155,21 @@ static void RemoveBoundsChk(Compiler* comp, GenTree** treeUse, Statement* stmt)
// comp - The compiler instance
// block - The block to clone
// bndChkStack - The stack of bounds checks to clone
// lastStmt - The last statement in the block (the block is split after this statement)
//
// Return Value:
// The block containing the fast path.
// The next block to visit after the cloning.
//
static BasicBlock* optRangeCheckCloning_DoClone(Compiler* comp, BasicBlock* block, BoundsCheckInfoStack* bndChkStack)
static BasicBlock* optRangeCheckCloning_DoClone(Compiler* comp,
BasicBlock* block,
BoundsCheckInfoStack* bndChkStack,
Statement* lastStmt)
{
assert(block != nullptr);
assert(bndChkStack->Height() > 0);

// The bounds checks are in execution order (the top of the stack is the last check)
BoundsCheckInfo firstCheck = bndChkStack->Bottom();
BoundsCheckInfo lastCheck = bndChkStack->Top();
BasicBlock* prevBb = block;

// First, split the block at the first bounds check using gtSplitTree (via fgSplitBlockBeforeTree):
Expand All @@ -187,7 +190,7 @@ static BasicBlock* optRangeCheckCloning_DoClone(Compiler* comp, BasicBlock* bloc
// Now split the block at the last bounds check using fgSplitBlockAfterStatement:
// TODO-RangeCheckCloning: call gtSplitTree for lastBndChkStmt as well, to cut off
// the stuff we don't have to clone.
BasicBlock* lastBb = comp->fgSplitBlockAfterStatement(fastpathBb, lastCheck.stmt);
BasicBlock* lastBb = comp->fgSplitBlockAfterStatement(fastpathBb, lastStmt);

DebugInfo debugInfo = fastpathBb->firstStmt()->GetDebugInfo();

Expand Down Expand Up @@ -359,6 +362,7 @@ class BoundsChecksVisitor final : public GenTreeVisitor<BoundsChecksVisitor>
{
Statement* m_stmt;
ArrayStack<BoundCheckLocation>* m_boundsChks;
int m_stmtIdx;

public:
enum
Expand All @@ -368,10 +372,14 @@ class BoundsChecksVisitor final : public GenTreeVisitor<BoundsChecksVisitor>
UseExecutionOrder = true
};

BoundsChecksVisitor(Compiler* compiler, Statement* stmt, ArrayStack<BoundCheckLocation>* bndChkLocations)
BoundsChecksVisitor(Compiler* compiler,
Statement* stmt,
int stmtIdx,
ArrayStack<BoundCheckLocation>* bndChkLocations)
: GenTreeVisitor(compiler)
, m_stmt(stmt)
, m_boundsChks(bndChkLocations)
, m_stmtIdx(stmtIdx)
{
}

Expand All @@ -389,7 +397,7 @@ class BoundsChecksVisitor final : public GenTreeVisitor<BoundsChecksVisitor>
{
if ((*use)->OperIs(GT_BOUNDS_CHECK))
{
m_boundsChks->Push(BoundCheckLocation(m_stmt, use));
m_boundsChks->Push(BoundCheckLocation(m_stmt, use, m_stmtIdx));
}
return fgWalkResult::WALK_CONTINUE;
}
Expand All @@ -401,6 +409,7 @@ class BoundsChecksVisitor final : public GenTreeVisitor<BoundsChecksVisitor>
// the bounds checks.
//
// Arguments:
// comp - The compiler instance
// bndChks - The stack of bounds checks
//
// Return Value:
Expand Down Expand Up @@ -499,8 +508,10 @@ PhaseStatus Compiler::optRangeCheckCloning()
bndChkLocations.Reset();
bndChkMap.RemoveAll();

int stmtIdx = -1;
for (Statement* const stmt : block->Statements())
{
stmtIdx++;
if (block->HasTerminator() && (stmt == block->lastStmt()))
{
// TODO-RangeCheckCloning: Splitting these blocks at the last statements
Expand All @@ -510,7 +521,7 @@ PhaseStatus Compiler::optRangeCheckCloning()

// Now just record all the bounds checks in the block (in the execution order)
//
BoundsChecksVisitor visitor(this, stmt, &bndChkLocations);
BoundsChecksVisitor visitor(this, stmt, stmtIdx, &bndChkLocations);
visitor.WalkTree(stmt->GetRootNodePointer(), nullptr);
}

Expand All @@ -528,7 +539,7 @@ PhaseStatus Compiler::optRangeCheckCloning()
{
BoundCheckLocation loc = bndChkLocations.Bottom(i);
BoundsCheckInfo bci{};
if (bci.Initialize(this, loc.stmt, loc.bndChkUse))
if (bci.Initialize(this, loc.stmt, loc.stmtIdx, loc.bndChkUse))
{
IdxLenPair key(bci.idxVN, bci.lenVN);
BoundsCheckInfoStack** value = bndChkMap.LookupPointerOrAdd(key, nullptr);
Expand All @@ -552,34 +563,72 @@ PhaseStatus Compiler::optRangeCheckCloning()
}

// Now collect every suitable group of bounds checks (at least MIN_CHECKS_PER_GROUP checks and within the complexity limit)
BoundsCheckInfoStack* largestGroup = nullptr;
ArrayStack<BoundsCheckInfoStack*> groups(getAllocator(CMK_RangeCheckCloning));

for (BoundsCheckInfoMap::Node* keyValuePair : BoundsCheckInfoMap::KeyValueIteration(&bndChkMap))
{
ArrayStack<BoundsCheckInfo>* value = keyValuePair->GetValue();
if ((largestGroup == nullptr) || (value->Height() > largestGroup->Height()))
if ((value->Height() >= MIN_CHECKS_PER_GROUP) && !DoesComplexityExceed(this, value))
{
if (DoesComplexityExceed(this, value))
{
continue;
}
largestGroup = value;
groups.Push(value);
}
}

if (largestGroup == nullptr)
if (groups.Height() == 0)
{
JITDUMP("No suitable group of bounds checks in the block - bail out.\n");
continue;
}

if (largestGroup->Height() < MIN_CHECKS_PER_GROUP)
// There may be multiple groups of bounds checks in the block.
// Let's pick the group that starts first in the block and the group whose last bounds check
// appears last in the block (these may be the same group).
//
BoundsCheckInfoStack* firstGroup = groups.Top();
BoundsCheckInfoStack* lastGroup = groups.Top();
for (int i = 0; i < groups.Height(); i++)
{
JITDUMP("Not enough bounds checks in the largest group - bail out.\n");
continue;
BoundsCheckInfoStack* group = groups.Bottom(i);
int firstStmt = group->Bottom().stmtIdx;
int secondStmt = group->Top().stmtIdx;
if (firstStmt < firstGroup->Bottom().stmtIdx)
{
firstGroup = group;
}
if (secondStmt > lastGroup->Top().stmtIdx)
{
lastGroup = group;
}
}

// We're going to clone for the first group.
// But let's see if we can extend the end of the group so future iterations
// can fit more groups in the same block.
//
Statement* lastStmt = firstGroup->Top().stmt;

int firstGroupStarts = firstGroup->Bottom().stmtIdx;
int firstGroupEnds = firstGroup->Top().stmtIdx;
int lastGroupStarts = lastGroup->Bottom().stmtIdx;
int lastGroupEnds = lastGroup->Top().stmtIdx;

// The only requirement is that both groups must overlap - we don't want to
// end up cloning unrelated statements between them (not a correctness issue,
// just a heuristic to avoid cloning too much).
//
if (firstGroupEnds < lastGroupEnds && firstGroupEnds >= lastGroupStarts)
{
lastStmt = lastGroup->Top().stmt;
}

JITDUMP("Cloning bounds checks in " FMT_BB "\n", block->bbNum);
block = optRangeCheckCloning_DoClone(this, block, largestGroup);
JITDUMP("Cloning bounds checks in " FMT_BB " from " FMT_STMT " to " FMT_STMT "\n", block->bbNum,
firstGroup->Bottom().stmt->GetID(), lastStmt->GetID());

BasicBlock* nextBbToVisit = optRangeCheckCloning_DoClone(this, block, firstGroup, lastStmt);
assert(nextBbToVisit != nullptr);
// optRangeCheckCloning_DoClone wants us to visit nextBbToVisit next
block = nextBbToVisit->Prev();
assert(block != nullptr);
Comment on lines +627 to +631
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another way you could do this without requiring the asserts would be:

BasicBlock* next;
for (BasicBlock* block = fgFirstBB; block != nullptr; block = next)
{
  next = block->Next();
  ...
  next = optRangeCheckCloning_DoClone(...);
}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, makes sense, I was thinking about goto 😄 I'll change in a follow up to avoid spinning CI

modified = true;
}

Expand Down
9 changes: 7 additions & 2 deletions src/coreclr/jit/rangecheckcloning.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,18 @@ struct BoundCheckLocation
{
Statement* stmt;
GenTree** bndChkUse;
int stmtIdx;

BoundCheckLocation(Statement* stmt, GenTree** bndChkUse)
BoundCheckLocation(Statement* stmt, GenTree** bndChkUse, int stmtIdx)
: stmt(stmt)
, bndChkUse(bndChkUse)
, stmtIdx(stmtIdx)
{
assert(stmt != nullptr);
assert((bndChkUse != nullptr));
assert((*bndChkUse) != nullptr);
assert((*bndChkUse)->OperIs(GT_BOUNDS_CHECK));
assert(stmtIdx >= 0);
}
};

Expand All @@ -41,17 +44,19 @@ struct BoundsCheckInfo
ValueNum lenVN;
ValueNum idxVN;
int offset;
int stmtIdx;

BoundsCheckInfo()
: stmt(nullptr)
, bndChkUse(nullptr)
, lenVN(ValueNumStore::NoVN)
, idxVN(ValueNumStore::NoVN)
, offset(0)
, stmtIdx(0)
{
}

bool Initialize(const Compiler* comp, Statement* statement, GenTree** bndChkUse);
bool Initialize(const Compiler* comp, Statement* statement, int statementIdx, GenTree** bndChkUse);

GenTreeBoundsChk* BndChk() const
{
Expand Down
Loading