JIT: Optimize SequenceEqual to use ccmp on ARM64 (#92810)
In the original PR we could not get this working due to some
conservative interference. This now does the right thing with #92710
merged.

Also change LowerCallMemcmp/LowerCallMemmove to return the next node to
lower, just to align them a bit more with other functions.
jakobbotsch authored Sep 29, 2023
1 parent e3d37a8 commit 1c6d909
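
As a rough illustration of the optimization (an editor's sketch, not part of the commit): for a constant-length SequenceEqual, LowerCallMemcmp unrolls the comparison into one of two shapes. In scalar C++ terms, assuming a 16-byte comparison done as two 8-byte loads and with illustrative function names:

#include <cstdint>

// XOR/OR shape: branch-free, kept for SIMD loads and for non-ARM64 targets.
bool equalsXorOr(const std::uint64_t* l, const std::uint64_t* r)
{
    return ((l[0] ^ r[0]) | (l[1] ^ r[1])) == 0;
}

// AND-of-equalities shape: now emitted for scalar loads on ARM64, where the
// two compares fold into a single flags chain (cmp + ccmp + cset).
bool equalsAndEq(const std::uint64_t* l, const std::uint64_t* r)
{
    // The non-short-circuiting '&' keeps both compares unconditional,
    // mirroring the GT_AND of two GT_EQ nodes built in LowerCallMemcmp.
    return (l[0] == r[0]) & (l[1] == r[1]);
}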
Showing 2 changed files with 85 additions and 64 deletions.
145 changes: 83 additions & 62 deletions src/coreclr/jit/lower.cpp
@@ -1798,11 +1798,12 @@ GenTree* Lowering::AddrGen(void* addr)
//
// Arguments:
// tree - GenTreeCall node to replace with STORE_BLK
// next - [out] Next node to lower if this function returns true
//
// Return Value:
// nullptr if no changes were made
// false if no changes were made
//
GenTree* Lowering::LowerCallMemmove(GenTreeCall* call)
bool Lowering::LowerCallMemmove(GenTreeCall* call, GenTree** next)
{
JITDUMP("Considering Memmove [%06d] for unrolling.. ", comp->dspTreeID(call))
assert(comp->lookupNamedIntrinsic(call->gtCallMethHnd) == NI_System_Buffer_Memmove);
@@ -1812,7 +1813,7 @@ GenTree* Lowering::LowerCallMemmove(GenTreeCall* call)
if (comp->info.compHasNextCallRetAddr)
{
JITDUMP("compHasNextCallRetAddr=true so we won't be able to remove the call - bail out.\n")
return nullptr;
return false;
}

GenTree* lengthArg = call->gtArgs.GetUserArgByIndex(2)->GetNode();
@@ -1857,7 +1858,9 @@ GenTree* Lowering::LowerCallMemmove(GenTreeCall* call)

JITDUMP("\nNew tree:\n")
DISPTREE(storeBlk);
return storeBlk;
// TODO: This skips lowering srcBlk and storeBlk.
*next = storeBlk->gtNext;
return true;
}
else
{
@@ -1868,7 +1871,7 @@ GenTree* Lowering::LowerCallMemmove(GenTreeCall* call)
{
JITDUMP("size is not a constant.\n")
}
return nullptr;
return false;
}

//------------------------------------------------------------------------
@@ -1877,11 +1880,12 @@ GenTree* Lowering::LowerCallMemmove(GenTreeCall* call)
//
// Arguments:
// tree - GenTreeCall node to unroll as memcmp
// next - [out] Next node to lower if this function returns true
//
// Return Value:
// nullptr if no changes were made
// false if no changes were made
//
GenTree* Lowering::LowerCallMemcmp(GenTreeCall* call)
bool Lowering::LowerCallMemcmp(GenTreeCall* call, GenTree** next)
{
JITDUMP("Considering Memcmp [%06d] for unrolling.. ", comp->dspTreeID(call))
assert(comp->lookupNamedIntrinsic(call->gtCallMethHnd) == NI_System_SpanHelpers_SequenceEqual);
@@ -1891,13 +1895,13 @@ GenTree* Lowering::LowerCallMemcmp(GenTreeCall* call)
if (!comp->opts.OptimizationEnabled())
{
JITDUMP("Optimizations aren't allowed - bail out.\n")
return nullptr;
return false;
}

if (comp->info.compHasNextCallRetAddr)
{
JITDUMP("compHasNextCallRetAddr=true so we won't be able to remove the call - bail out.\n")
return nullptr;
return false;
}

GenTree* lengthArg = call->gtArgs.GetUserArgByIndex(2)->GetNode();
@@ -2004,9 +2008,8 @@ GenTree* Lowering::LowerCallMemcmp(GenTreeCall* call)
GenTree* rIndir = comp->gtNewIndir(loadType, rArg);
result = newBinaryOp(comp, GT_EQ, TYP_INT, lIndir, rIndir);

BlockRange().InsertAfter(lArg, lIndir);
BlockRange().InsertAfter(rArg, rIndir);
BlockRange().InsertBefore(call, result);
BlockRange().InsertBefore(call, lIndir, rIndir, result);
*next = lIndir;
}
else
{
@@ -2020,51 +2023,77 @@ GenTree* Lowering::LowerCallMemcmp(GenTreeCall* call)
GenTree* rArgClone = comp->gtNewLclvNode(rArgUse.ReplaceWithLclVar(comp), genActualType(rArg));
BlockRange().InsertBefore(call, lArgClone, rArgClone);

// We're going to emit something like the following:
//
// bool result = ((*(int*)leftArg ^ *(int*)rightArg) |
// (*(int*)(leftArg + 1) ^ *((int*)(rightArg + 1)))) == 0;
//
// ^ in the given example we unroll for length=5
//
// In IR:
//
// * EQ int
// +--* OR int
// | +--* XOR int
// | | +--* IND int
// | | | \--* LCL_VAR byref V1
// | | \--* IND int
// | | \--* LCL_VAR byref V2
// | \--* XOR int
// | +--* IND int
// | | \--* ADD byref
// | | +--* LCL_VAR byref V1
// | | \--* CNS_INT int 1
// | \--* IND int
// | \--* ADD byref
// | +--* LCL_VAR byref V2
// | \--* CNS_INT int 1
// \--* CNS_INT int 0
//
*next = lArgClone;

GenTree* l1Indir = comp->gtNewIndir(loadType, lArgUse.Def());
GenTree* r1Indir = comp->gtNewIndir(loadType, rArgUse.Def());
GenTree* lXor = newBinaryOp(comp, GT_XOR, actualLoadType, l1Indir, r1Indir);
GenTree* l2Offs = comp->gtNewIconNode(cnsSize - loadWidth, TYP_I_IMPL);
GenTree* l2AddOffs = newBinaryOp(comp, GT_ADD, lArg->TypeGet(), lArgClone, l2Offs);
GenTree* l2Indir = comp->gtNewIndir(loadType, l2AddOffs);
GenTree* r2Offs = comp->gtCloneExpr(l2Offs); // offset is the same
GenTree* r2Offs = comp->gtNewIconNode(cnsSize - loadWidth, TYP_I_IMPL);
GenTree* r2AddOffs = newBinaryOp(comp, GT_ADD, rArg->TypeGet(), rArgClone, r2Offs);
GenTree* r2Indir = comp->gtNewIndir(loadType, r2AddOffs);
GenTree* rXor = newBinaryOp(comp, GT_XOR, actualLoadType, l2Indir, r2Indir);
GenTree* resultOr = newBinaryOp(comp, GT_OR, actualLoadType, lXor, rXor);
GenTree* zeroCns = comp->gtNewZeroConNode(actualLoadType);
result = newBinaryOp(comp, GT_EQ, TYP_INT, resultOr, zeroCns);

BlockRange().InsertAfter(rArgClone, l1Indir, l2Offs, l2AddOffs, l2Indir);
BlockRange().InsertAfter(l2Indir, r1Indir, r2Offs, r2AddOffs, r2Indir);
BlockRange().InsertAfter(r2Indir, lXor, rXor, resultOr, zeroCns);
BlockRange().InsertAfter(zeroCns, result);

#ifdef TARGET_ARM64
if (!varTypeIsSIMD(loadType))
{
// ARM64 will get efficient ccmp codegen if we emit the normal thing:
//
// bool result = (*(int*)leftArg == *(int*)rightArg) & (*(int*)(leftArg + 1) == *(int*)(rightArg + 1))

GenTree* eq1 = newBinaryOp(comp, GT_EQ, TYP_INT, l1Indir, r1Indir);
GenTree* eq2 = newBinaryOp(comp, GT_EQ, TYP_INT, l2Indir, r2Indir);
result = newBinaryOp(comp, GT_AND, TYP_INT, eq1, eq2);

BlockRange().InsertAfter(r2Indir, eq1, eq2, result);
}
#endif

if (result == nullptr)
{
// We're going to emit something like the following:
//
// bool result = ((*(int*)leftArg ^ *(int*)rightArg) |
// (*(int*)(leftArg + 1) ^ *((int*)(rightArg + 1)))) == 0;
//
// ^ in the given example we unroll for length=5
//
// In IR:
//
// * EQ int
// +--* OR int
// | +--* XOR int
// | | +--* IND int
// | | | \--* LCL_VAR byref V1
// | | \--* IND int
// | | \--* LCL_VAR byref V2
// | \--* XOR int
// | +--* IND int
// | | \--* ADD byref
// | | +--* LCL_VAR byref V1
// | | \--* CNS_INT int 1
// | \--* IND int
// | \--* ADD byref
// | +--* LCL_VAR byref V2
// | \--* CNS_INT int 1
// \--* CNS_INT int 0
//
// TODO-CQ: Do this as a general optimization similar to TryLowerAndOrToCCMP.

GenTree* lXor = newBinaryOp(comp, GT_XOR, actualLoadType, l1Indir, r1Indir);
GenTree* rXor = newBinaryOp(comp, GT_XOR, actualLoadType, l2Indir, r2Indir);
GenTree* resultOr = newBinaryOp(comp, GT_OR, actualLoadType, lXor, rXor);
GenTree* zeroCns = comp->gtNewZeroConNode(actualLoadType);
result = newBinaryOp(comp, GT_EQ, TYP_INT, resultOr, zeroCns);

BlockRange().InsertAfter(r2Indir, lXor, rXor, resultOr, zeroCns);
BlockRange().InsertAfter(zeroCns, result);
}
}

JITDUMP("\nUnrolled to:\n");
@@ -2090,7 +2119,7 @@ GenTree* Lowering::LowerCallMemcmp(GenTreeCall* call)
arg.GetNode()->SetUnusedValue();
}
}
return lArg;
return true;
}
}
else
@@ -2102,7 +2131,7 @@ GenTree* Lowering::LowerCallMemcmp(GenTreeCall* call)
{
JITDUMP("size is not a constant.\n")
}
return nullptr;
return false;
}

// do lowering steps for a call
@@ -2133,20 +2162,12 @@ GenTree* Lowering::LowerCall(GenTree* node)
#if defined(TARGET_AMD64) || defined(TARGET_ARM64)
if (call->gtCallMoreFlags & GTF_CALL_M_SPECIAL_INTRINSIC)
{
GenTree* newNode = nullptr;
NamedIntrinsic ni = comp->lookupNamedIntrinsic(call->gtCallMethHnd);
if (ni == NI_System_Buffer_Memmove)
{
newNode = LowerCallMemmove(call);
}
else if (ni == NI_System_SpanHelpers_SequenceEqual)
GenTree* nextNode = nullptr;
NamedIntrinsic ni = comp->lookupNamedIntrinsic(call->gtCallMethHnd);
if (((ni == NI_System_Buffer_Memmove) && LowerCallMemmove(call, &nextNode)) ||
((ni == NI_System_SpanHelpers_SequenceEqual) && LowerCallMemcmp(call, &nextNode)))
{
newNode = LowerCallMemcmp(call);
}

if (newNode != nullptr)
{
return newNode->gtNext;
return nextNode;
}
}
#endif
4 changes: 2 additions & 2 deletions src/coreclr/jit/lower.h
@@ -133,8 +133,8 @@ class Lowering final : public Phase
// Call Lowering
// ------------------------------
GenTree* LowerCall(GenTree* call);
GenTree* LowerCallMemmove(GenTreeCall* call);
GenTree* LowerCallMemcmp(GenTreeCall* call);
bool LowerCallMemmove(GenTreeCall* call, GenTree** next);
bool LowerCallMemcmp(GenTreeCall* call, GenTree** next);
void LowerCFGCall(GenTreeCall* call);
void MoveCFGCallArg(GenTreeCall* call, GenTree* node);
#ifndef TARGET_64BIT
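
As a closing note (again a sketch, not actual JIT output), this is why the AND-of-equalities shape pays off on ARM64: the two compares collapse into a conditional-compare chain. Register names are hypothetical:

#include <cstdint>

bool bothEqual(std::uint64_t a, std::uint64_t b, std::uint64_t c, std::uint64_t d)
{
    // Expected A64 lowering, roughly:
    //     cmp  x0, x1           // first compare sets NZCV
    //     ccmp x2, x3, #0, eq   // second compare runs only if 'eq' held;
    //                           // otherwise NZCV is forced to 0 (not-equal)
    //     cset w0, eq           // materialize the combined predicate
    // versus eor + eor + orr + cmp + cset for the XOR/OR shape.
    return (a == b) & (c == d);
}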
