Arm64/SVE: Add support to handle predicate registers as callee-trash #104065

Merged · 15 commits · Jun 28, 2024
36 changes: 34 additions & 2 deletions src/coreclr/jit/emitarm64.cpp
@@ -7884,7 +7884,22 @@ void emitter::emitIns_R_S(instruction ins, emitAttr attr, regNumber reg1, int va
isSimple = false;
size = EA_SCALABLE;
attr = size;
- fmt = isVectorRegister(reg1) ? IF_SVE_IE_2A : IF_SVE_ID_2A;
+ if (isPredicateRegister(reg1))
+ {
+     assert(offs == 0);
+     // For predicate, generate based off rsGetRsvdReg()
+     regNumber rsvdReg = codeGen->rsGetRsvdReg();
+
+     // add rsvd, fp, #imm
+     emitIns_R_R_I(INS_add, EA_8BYTE, rsvdReg, reg2, imm);
+     // ldr p0, [rsvd, #0, mul vl]
+     emitIns_R_R_I(ins, attr, reg1, rsvdReg, 0);
+
+     return;
+ }
+
+ assert(isVectorRegister(reg1));
+ fmt = IF_SVE_IE_2A;

Contributor:
Eventually, I wonder if this code (for SVE vectors) should be refactored to call out to an emit_R_R_I function instead of falling into the non-SVE code below.

Member Author:
I agree.


// TODO-SVE: Don't assume 128bit vectors
// Predicate size is vector length / 8
@@ -8138,7 +8153,24 @@ void emitter::emitIns_S_R(instruction ins, emitAttr attr, regNumber reg1, int va
isSimple = false;
size = EA_SCALABLE;
attr = size;
- fmt = isVectorRegister(reg1) ? IF_SVE_JH_2A : IF_SVE_JG_2A;
+
+ if (isPredicateRegister(reg1))
+ {
+     assert(offs == 0);
+
+     // For predicate, generate based off rsGetRsvdReg()
+     regNumber rsvdReg = codeGen->rsGetRsvdReg();
+
+     // add rsvd, fp, #imm
+     emitIns_R_R_I(INS_add, EA_8BYTE, rsvdReg, reg2, imm);
+     // str p0, [rsvd, #0, mul vl]
+     emitIns_R_R_I(ins, attr, reg1, rsvdReg, 0);
+
+     return;
+ }
+
+ assert(isVectorRegister(reg1));
+ fmt = IF_SVE_JH_2A;

// TODO-SVE: Don't assume 128bit vectors
// Predicate size is vector length / 8
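
For context, a minimal standalone sketch (not JIT code) of what the two new predicate paths above conceptually emit: SVE predicate LDR/STR only accepts a [Xn, #imm, mul vl] style address, so the frame offset is first materialized into the reserved register returned by codeGen->rsGetRsvdReg(). The register x16, the predicate p0, and the offset below are placeholders for illustration, not values taken from the PR.

```cpp
// Standalone sketch: print the two-instruction sequence used for a predicate
// register fill (emitIns_R_S) or spill (emitIns_S_R). All names are
// illustrative; x16 merely stands in for the JIT's reserved register.
#include <cstdio>

static void predicateFrameAccess(bool isStore, const char* preg, int frameOffset)
{
    // add rsvd, fp, #imm  -- form the address of the stack slot
    std::printf("add x16, fp, #%d\n", frameOffset);
    // ldr/str p<n>, [rsvd, #0, mul vl]  -- access the predicate slot
    std::printf("%s %s, [x16, #0, mul vl]\n", isStore ? "str" : "ldr", preg);
}

int main()
{
    predicateFrameAccess(false, "p0", -32); // fill p0 from a local at [fp, #-32]
    predicateFrameAccess(true, "p0", -32);  // spill p0 back to the same slot
}
```
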
5 changes: 5 additions & 0 deletions src/coreclr/jit/emitarm64.h
@@ -1205,6 +1205,11 @@ inline static bool isHighPredicateRegister(regNumber reg)
return (reg >= REG_PREDICATE_HIGH_FIRST) && (reg <= REG_PREDICATE_HIGH_LAST);
}

+ inline static bool isMaskReg(regNumber reg)
+ {
+     return isPredicateRegister(reg);
+ }

inline static bool isEvenRegister(regNumber reg)
{
if (isGeneralRegister(reg))
2 changes: 1 addition & 1 deletion src/coreclr/jit/lclvars.cpp
@@ -5593,7 +5593,7 @@ unsigned Compiler::lvaGetMaxSpillTempSize()
* Doing this all in one pass is 'hard'. So instead we do it in 2 basic passes:
* 1. Assign all the offsets relative to the Virtual '0'. Offsets above (the
* incoming arguments) are positive. Offsets below (everything else) are
- negative. This pass also calcuates the total frame size (between Caller's
+ negative. This pass also calculates the total frame size (between Caller's
* SP/return address and the Ambient SP).
* 2. Figure out where to place the frame pointer, and then adjust the offsets
* as needed for the final stack size and whether the offset is frame pointer
4 changes: 2 additions & 2 deletions src/coreclr/jit/lsra.h
@@ -508,13 +508,13 @@ class RegRecord : public Referenceable
{
registerType = FloatRegisterType;
}
- #if defined(TARGET_XARCH) && defined(FEATURE_SIMD)
+ #if defined(FEATURE_MASKED_HW_INTRINSICS)
else
{
assert(emitter::isMaskReg(reg));
registerType = MaskRegisterType;
}
- #endif
+ #endif // FEATURE_MASKED_HW_INTRINSICS
regNum = reg;
isCalleeSave = ((RBM_CALLEE_SAVED & genRegMask(reg)) != 0);
}
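
As an aside, a toy sketch of the register-file tagging the RegRecord constructor above performs; the flat numbering and the helper below are made up for illustration and are not the JIT's regNumber/RegRecord types.

```cpp
// Toy illustration only: each physical register record is tagged with the
// register file it belongs to, and the mask (predicate) file exists only when
// FEATURE_MASKED_HW_INTRINSICS is defined (as in the lsra.h change above).
#include <cassert>

enum RegisterType { IntRegisterType, FloatRegisterType, MaskRegisterType };

// Hypothetical flat numbering: 0-30 integer, 31-62 vector, 63-78 predicate.
static RegisterType registerTypeOf(int reg)
{
    if (reg <= 30)
        return IntRegisterType;
    if (reg <= 62)
        return FloatRegisterType;
    assert(reg <= 78); // p0-p15
    return MaskRegisterType;
}

int main()
{
    return (registerTypeOf(63) == MaskRegisterType) ? 0 : 1; // "p0" maps to the mask file
}
```
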
7 changes: 5 additions & 2 deletions src/coreclr/jit/lsrabuild.cpp
@@ -855,6 +855,9 @@ regMaskTP LinearScan::getKillSetForCall(GenTreeCall* call)

#else
killMask.RemoveRegsetForType(RBM_FLT_CALLEE_TRASH.GetFloatRegSet(), FloatRegisterType);
+ #if defined(TARGET_ARM64)
+     killMask.RemoveRegsetForType(RBM_MSK_CALLEE_TRASH.GetPredicateRegSet(), MaskRegisterType);
+ #endif // TARGET_ARM64
#endif // TARGET_XARCH
}
#ifdef TARGET_ARM
@@ -1148,8 +1151,8 @@ bool LinearScan::buildKillPositionsForNode(GenTree* tree, LsraLocation currentLo
{
continue;
}
- Interval* interval = getIntervalForLocalVar(varIndex);
- const bool isCallKill = ((killMask == RBM_INT_CALLEE_TRASH) || (killMask == RBM_CALLEE_TRASH));
+ Interval* interval = getIntervalForLocalVar(varIndex);
+ const bool isCallKill = ((killMask.getLow() == RBM_INT_CALLEE_TRASH) || (killMask == RBM_CALLEE_TRASH));
SingleTypeRegSet regsKillMask = killMask.GetRegSetForType(interval->registerType);

if (isCallKill)
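
To make the per-type kill handling above concrete, here is a much-simplified, self-contained sketch. The KillMask type, the trash-set bit layouts, and the preservesSimdAndMaskState flag are stand-ins invented for this illustration, not the JIT's regMaskTP or its helper-call classification.

```cpp
// Simplified model (not the JIT's regMaskTP/LinearScan): a kill mask segmented
// by register type, where calls known to preserve SIMD and predicate state
// drop the float and mask segments, mirroring the RemoveRegsetForType calls in
// getKillSetForCall above.
#include <cstdint>

enum RegisterType { IntRegisterType, FloatRegisterType, MaskRegisterType };

struct KillMask
{
    uint64_t segs[3] = {0, 0, 0}; // one segment per register file

    void removeRegsetForType(uint64_t regs, RegisterType type)
    {
        segs[type] &= ~regs;
    }
};

// Illustrative callee-trash sets (bit layouts are made up for this sketch).
constexpr uint64_t INT_CALLEE_TRASH = (1ull << 18) - 1; // "x0-x17"
constexpr uint64_t FLT_CALLEE_TRASH = 0xFFFF00FFull;    // "v0-v7, v16-v31"
constexpr uint64_t MSK_CALLEE_TRASH = (1ull << 16) - 1; // "p0-p15"

KillMask getKillSetForCall(bool preservesSimdAndMaskState)
{
    KillMask kill;
    kill.segs[IntRegisterType]   = INT_CALLEE_TRASH;
    kill.segs[FloatRegisterType] = FLT_CALLEE_TRASH;
    kill.segs[MaskRegisterType]  = MSK_CALLEE_TRASH;

    if (preservesSimdAndMaskState)
    {
        // Calls known to leave vector/predicate state alone only kill integer registers.
        kill.removeRegsetForType(FLT_CALLEE_TRASH, FloatRegisterType);
        kill.removeRegsetForType(MSK_CALLEE_TRASH, MaskRegisterType);
    }
    return kill;
}

int main()
{
    KillMask k = getKillSetForCall(/* preservesSimdAndMaskState */ true);
    return (k.segs[MaskRegisterType] == 0) ? 0 : 1; // predicate kills were removed
}
```
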
17 changes: 8 additions & 9 deletions src/coreclr/jit/targetarm64.h
@@ -75,8 +75,15 @@
#define RBM_FLT_CALLEE_SAVED (RBM_V8|RBM_V9|RBM_V10|RBM_V11|RBM_V12|RBM_V13|RBM_V14|RBM_V15)
#define RBM_FLT_CALLEE_TRASH (RBM_V0|RBM_V1|RBM_V2|RBM_V3|RBM_V4|RBM_V5|RBM_V6|RBM_V7|RBM_V16|RBM_V17|RBM_V18|RBM_V19|RBM_V20|RBM_V21|RBM_V22|RBM_V23|RBM_V24|RBM_V25|RBM_V26|RBM_V27|RBM_V28|RBM_V29|RBM_V30|RBM_V31)

+ #define RBM_LOWMASK (RBM_P0|RBM_P1|RBM_P2|RBM_P3|RBM_P4|RBM_P5|RBM_P6|RBM_P7)
+ #define RBM_HIGHMASK (RBM_P8|RBM_P9|RBM_P10|RBM_P11|RBM_P12|RBM_P13|RBM_P14|RBM_P15)
+ #define RBM_ALLMASK (RBM_LOWMASK|RBM_HIGHMASK)
+
+ #define RBM_MSK_CALLEE_SAVED (0)
+ #define RBM_MSK_CALLEE_TRASH RBM_ALLMASK

Member Author:
Somewhere I should just zero it out if we are not running on an SVE machine.

Member:
Is the TP cost coming from the additional killed registers? I assume that's because we don't have a predicate-register equivalent of compFloatingPointUsed.

I wonder if you could just add a case for predicate registers here:

compiler->compFloatingPointUsed = true;

And then during allocation, mask out the predicate registers when processing kills if no predicate registers were used.

We would still be creating additional RegRecords though, but maybe this helps a bit.

Member:
> We would still be creating additional RegRecords though, but maybe this helps a bit.

Actually I guess we were creating those RegRecords even before this PR, so I imagine it would help quite a bit for this PR.

Member Author:
I am already doing it in https://github.com/dotnet/runtime/pull/104065/files#diff-ad66a6bcf1fd550d5ad10d995c03218afbbc39463d36e1f2a224f9ca070a2f99R858-R860. Predicate registers are used only in the presence of floating-point usage. Yes, we do process the newly added predicate registers in processKills(), and that is what shows up as the TP impact. On a non-SVE arm64 machine, we don't have to iterate through them.

Member Author:
> Somewhere I should just zero it out if we are not running on an SVE machine.

Although note that when we run as an altjit, we say that SVE capability is enabled, so we will see predicate registers and will process them during kills. The TP information will be misleading for those cases, but I will add this anyway so that on non-SVE arm64 machines, we do not process them.

Member:
I think the use of predicate registers is going to be much more rare than using float registers, hence adding this extra check would help regardless.

> I will add this anyway so that on non-SVE arm64 machines, we do not process them.

I don't see a good reason to try optimizing for non-SVE machines. In the future we would expect most arm64 machines to be SVE enabled, right?

I think we should rather optimize for the common case of "predicate registers not used". It should be possible now that we are only creating one RefTypeKill per call.

Member Author:
> I don't see a good reason to try optimizing for non-SVE machines. In the future we would expect most arm64 machines to be SVE enabled, right?

Yes

> I think the use of predicate registers is going to be much more rare than using float registers, hence adding this extra check would help regardless.

Agree. I will do a separate pass for it; #104157 tracks it.

Member Author:
> I don't see a good reason to try optimizing for non-SVE machines

Thinking about this a bit more, I think I will just revert bcfd8a8 and do it properly in #104157.
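
Below is a rough sketch of the optimization discussed in this thread, with hypothetical names: compMaskRegsUsed and the simplified types are stand-ins invented here (the JIT has compFloatingPointUsed but, per the discussion, no predicate equivalent yet), and the actual work is deferred to #104157.

```cpp
// Hypothetical sketch only: track whether any predicate register was used,
// analogous to compFloatingPointUsed, and drop the predicate portion of the
// call kill set when none were, so the allocator never visits those RegRecords.
#include <cstdint>

struct CompilerState
{
    bool compFloatingPointUsed = false;
    bool compMaskRegsUsed      = false; // invented name; would be set wherever predicate regs are defined/used
};

constexpr uint64_t RBM_MSK_CALLEE_TRASH = 0xFFFF; // p0-p15, made-up bit layout

uint64_t getMaskKillSetForCall(const CompilerState& comp)
{
    // If the method never touched a predicate register, nothing can be live in
    // one across the call, so there is nothing to kill (or to iterate over).
    return comp.compMaskRegsUsed ? RBM_MSK_CALLEE_TRASH : 0;
}

int main()
{
    CompilerState comp; // no predicate registers used in this method
    return static_cast<int>(getMaskKillSetForCall(comp)); // 0: no extra kills
}
```
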


#define RBM_CALLEE_SAVED (RBM_INT_CALLEE_SAVED | RBM_FLT_CALLEE_SAVED)
- #define RBM_CALLEE_TRASH (RBM_INT_CALLEE_TRASH | RBM_FLT_CALLEE_TRASH)
+ #define RBM_CALLEE_TRASH (RBM_INT_CALLEE_TRASH | RBM_FLT_CALLEE_TRASH | RBM_MSK_CALLEE_TRASH)

#define REG_DEFAULT_HELPER_CALL_TARGET REG_R12
#define RBM_DEFAULT_HELPER_CALL_TARGET RBM_R12
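
As a sanity check on the mask arithmetic introduced above, a small compile-time sketch; the bit positions are made up for this illustration and differ from CoreCLR's actual RBM_Pn encodings. It checks that all sixteen predicate registers form RBM_ALLMASK, that none are callee-saved, and that the whole set is treated as callee-trash.

```cpp
// Compile-time sketch with made-up bit positions (p0-p15 in bits 0-15 of a
// standalone mask); CoreCLR's real RBM_Pn values live alongside the other
// register bits and differ from this.
#include <cstdint>

constexpr uint64_t bit(unsigned i) { return 1ull << i; }

constexpr uint64_t RBM_LOWMASK  = bit(0) | bit(1) | bit(2)  | bit(3)  | bit(4)  | bit(5)  | bit(6)  | bit(7);
constexpr uint64_t RBM_HIGHMASK = bit(8) | bit(9) | bit(10) | bit(11) | bit(12) | bit(13) | bit(14) | bit(15);
constexpr uint64_t RBM_ALLMASK  = RBM_LOWMASK | RBM_HIGHMASK;

constexpr uint64_t RBM_MSK_CALLEE_SAVED = 0;
constexpr uint64_t RBM_MSK_CALLEE_TRASH = RBM_ALLMASK;

static_assert(RBM_MSK_CALLEE_TRASH == 0xFFFF, "all 16 predicate registers are treated as callee-trash");
static_assert((RBM_MSK_CALLEE_SAVED & RBM_MSK_CALLEE_TRASH) == 0, "no predicate register is both saved and trashed");

int main() {} // nothing to run; the checks above are compile-time only
```
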
@@ -146,14 +153,6 @@
#define REG_JUMP_THUNK_PARAM REG_R12
#define RBM_JUMP_THUNK_PARAM RBM_R12

- #define RBM_LOWMASK (RBM_P0 | RBM_P1 | RBM_P2 | RBM_P3 | RBM_P4 | RBM_P5 | RBM_P6 | RBM_P7)
- #define RBM_HIGHMASK (RBM_P8 | RBM_P9 | RBM_P10 | RBM_P11 | RBM_P12 | RBM_P13 | RBM_P14 | RBM_P15)
- #define RBM_ALLMASK (RBM_LOWMASK | RBM_HIGHMASK)
-
- // TODO-SVE: Fix when adding predicate register allocation
- #define RBM_MSK_CALLEE_SAVED (0)
- #define RBM_MSK_CALLEE_TRASH (0)

// ARM64 write barrier ABI (see vm\arm64\asmhelpers.asm, vm\arm64\asmhelpers.S):
// CORINFO_HELP_ASSIGN_REF (JIT_WriteBarrier), CORINFO_HELP_CHECKED_ASSIGN_REF (JIT_CheckedWriteBarrier):
// On entry: