Arm64/SVE: Add support to handle predicate registers as callee-trash #104065

Merged · 15 commits · Jun 28, 2024 · Changes from 5 commits
3 changes: 3 additions & 0 deletions src/coreclr/jit/codegencommon.cpp
@@ -215,6 +215,9 @@ CodeGen::CodeGen(Compiler* theCompiler)
#endif // DEBUG

regSet.tmpInit();
#if defined(TARGET_ARM64)
predicateOffset = 0;
#endif

#ifdef LATE_DISASM
getDisAssembler().disInit(compiler);
7 changes: 6 additions & 1 deletion src/coreclr/jit/codegeninterface.h
@@ -141,7 +141,12 @@ class CodeGenInterface
RegState intRegState;
RegState floatRegState;
NodeInternalRegisters internalRegisters;

#if defined(TARGET_ARM64)
// Tracks the stack offset of the first *temp* predicate
// register; it is then used to produce the stack address
// at which predicate temps are loaded/stored.
int predicateOffset;
#endif // TARGET_ARM64
protected:
Compiler* compiler;
bool m_genAlignLoops;
18 changes: 17 additions & 1 deletion src/coreclr/jit/compiler.h
@@ -1430,12 +1430,21 @@ class TempDsc
int tdNum;
BYTE tdSize;
var_types tdType;
#if defined(TARGET_ARM64)
// Only used for TYP_MASK to track the sequence of predicate
// register temps. We use this to ld/st them from the stack
// using `ldr pX, [sp, #seqNum, mul vl]`.
BYTE tdSeqNum;
#endif // TARGET_ARM64

public:
TempDsc(int _tdNum, unsigned _tdSize, var_types _tdType)
TempDsc(int _tdNum, unsigned _tdSize, var_types _tdType, unsigned _tdSeqNum)
: tdNum(_tdNum)
, tdSize((BYTE)_tdSize)
, tdType(_tdType)
#if defined(TARGET_ARM64)
, tdSeqNum((BYTE)_tdSeqNum)
#endif // TARGET_ARM64
{
#ifdef DEBUG
// temps must have a negative number (so they have a different number from all local variables)
@@ -1484,6 +1493,13 @@
{
return tdType;
}
#ifdef TARGET_ARM64
unsigned tdTempSeqNum() const
{
assert(varTypeIsMask(tdType));
return tdSeqNum;
}
#endif
};

// Specify compiler data that a phase might modify
11 changes: 10 additions & 1 deletion src/coreclr/jit/compiler.hpp
@@ -2708,7 +2708,16 @@ inline
tmpDsc = codeGen->regSet.tmpFindNum(varNum, RegSet::TEMP_USAGE_USED);
}
assert(tmpDsc != nullptr);
varOffset = tmpDsc->tdTempOffs();
#if defined(TARGET_ARM64)
if (varTypeIsMask(tmpDsc->tdTempType()))
{
varOffset = tmpDsc->tdTempSeqNum();
}
else
#endif // TARGET_ARM64
{
varOffset = tmpDsc->tdTempOffs();
}
}
else
{
64 changes: 62 additions & 2 deletions src/coreclr/jit/emitarm64.cpp
@@ -7884,7 +7884,36 @@ void emitter::emitIns_R_S(instruction ins, emitAttr attr, regNumber reg1, int va
isSimple = false;
size = EA_SCALABLE;
attr = size;
fmt = isVectorRegister(reg1) ? IF_SVE_IE_2A : IF_SVE_ID_2A;
if (isPredicateRegister(reg1))
{
assert(offs == 0);
// For predicate, generate based off rsGetRsvdReg()
regNumber rsvdReg = codeGen->rsGetRsvdReg();

if (varx >= 0)
{
// local

// add rsvd, fp, #imm
emitIns_R_R_I(INS_add, EA_8BYTE, rsvdReg, reg2, imm);
// ldr p0, [rsvd, #0, mul vl]
emitIns_R_R_I(ins, attr, reg1, rsvdReg, 0);
}
else
{
// temp

// `base` contains seqNum and offs = 0, so imm contains seqNum
// add rsvd, fp, #predicateStartOffset
emitIns_R_R_I(INS_add, EA_8BYTE, rsvdReg, reg2, codeGen->predicateOffset);
// ldr p0, [rsvd, #imm, mul vl]
emitIns_R_R_I(ins, attr, reg1, rsvdReg, imm);
}
return;
}

assert(isVectorRegister(reg1));
fmt = IF_SVE_IE_2A;
Contributor: Eventually, I wonder if this code (for SVE vectors) should be refactored to call out to an emit_R_R_I function instead of falling into the non-SVE code below.

Member (author): I agree.
// TODO-SVE: Don't assume 128bit vectors
// Predicate size is vector length / 8
@@ -8135,7 +8164,38 @@ void emitter::emitIns_S_R(instruction ins, emitAttr attr, regNumber reg1, int va
isSimple = false;
size = EA_SCALABLE;
attr = size;
fmt = isVectorRegister(reg1) ? IF_SVE_JH_2A : IF_SVE_JG_2A;

if (isPredicateRegister(reg1))
{
assert(offs == 0);

// For predicate, generate based off rsGetRsvdReg()
regNumber rsvdReg = codeGen->rsGetRsvdReg();

if (varx >= 0)
{
// local

// add rsvd, fp, #imm
emitIns_R_R_I(INS_add, EA_8BYTE, rsvdReg, reg2, imm);
// str p0, [rsvd, #0, mul vl]
emitIns_R_R_I(ins, attr, reg1, rsvdReg, 0);
}
else
{
// temp

// `base` contains seqNum and offs = 0, so imm contains seqNum
// add rsvd, fp, #predicateStartOffset
emitIns_R_R_I(INS_add, EA_8BYTE, rsvdReg, reg2, codeGen->predicateOffset);
// str p0, [rsvd, #seqNum, mul vl]
emitIns_R_R_I(ins, attr, reg1, rsvdReg, imm);
}
return;
}

assert(isVectorRegister(reg1));
fmt = IF_SVE_JH_2A;

// TODO-SVE: Don't assume 128bit vectors
// Predicate size is vector length / 8
5 changes: 5 additions & 0 deletions src/coreclr/jit/emitarm64.h
@@ -1205,6 +1205,11 @@ inline static bool isHighPredicateRegister(regNumber reg)
return (reg >= REG_PREDICATE_HIGH_FIRST) && (reg <= REG_PREDICATE_HIGH_LAST);
}

inline static bool isMaskReg(regNumber reg)
{
return isPredicateRegister(reg);
}

inline static bool isEvenRegister(regNumber reg)
{
if (isGeneralRegister(reg))
10 changes: 9 additions & 1 deletion src/coreclr/jit/lclvars.cpp
@@ -5593,7 +5593,7 @@ unsigned Compiler::lvaGetMaxSpillTempSize()
* Doing this all in one pass is 'hard'. So instead we do it in 2 basic passes:
* 1. Assign all the offsets relative to the Virtual '0'. Offsets above (the
* incoming arguments) are positive. Offsets below (everything else) are
* negative. This pass also calcuates the total frame size (between Caller's
* negative. This pass also calculates the total frame size (between Caller's
* SP/return address and the Ambient SP).
* 2. Figure out where to place the frame pointer, and then adjust the offsets
* as needed for the final stack size and whether the offset is frame pointer
@@ -5872,6 +5872,14 @@ void Compiler::lvaFixVirtualFrameOffsets()
for (TempDsc* temp = codeGen->regSet.tmpListBeg(); temp != nullptr; temp = codeGen->regSet.tmpListNxt(temp))
{
temp->tdAdjustTempOffs(delta);
#if defined(TARGET_ARM64)
Member: Is there some guarantee that all the predicate temps end up adjacent on this list? Otherwise it seems like this indexing scheme might not work out.

Member (author): If you see below, we iterate over all the types and call tmpPreAllocateTemps with the number of slots we need for that type:

for (int i = 0; i < TYP_COUNT; i++)
{
if (var_types(i) != RegSet::tmpNormalizeType(var_types(i)))
{
// Only normalized types should have anything in the maxSpill array.
// We assume here that if type 'i' does not normalize to itself, then
// nothing else normalizes to 'i', either.
assert(maxSpill[i] == 0);
}
if (maxSpill[i] != 0)
{
JITDUMP(" %s: %d\n", varTypeName(var_types(i)), maxSpill[i]);
compiler->codeGen->regSet.tmpPreAllocateTemps(var_types(i), maxSpill[i]);
}
}

In tmpPreAllocateTemps(), we iterate through the number of slots we want to allocate and create them:

for (unsigned i = 0; i < count; i++)
{
tmpCount++;
tmpSize += size;
#ifdef TARGET_ARM
if (type == TYP_DOUBLE)
{
// Adjust tmpSize to accommodate possible alignment padding.
// Note that at this point the offsets aren't yet finalized, so we don't yet know if it will be required.
tmpSize += TARGET_POINTER_SIZE;
}
#endif // TARGET_ARM
TempDsc* temp = new (m_rsCompiler, CMK_Unknown) TempDsc(-((int)tmpCount), size, type);

if (varTypeIsMask(temp->tdTempType()) && temp->tdTempSeqNum() == 0)
{
// For the first register, store the offset, which we will use to
// generate the offsets for subsequent temp mask registers
codeGen->predicateOffset = temp->tdTempOffs();
}
#endif
}

lvaCachedGenericContextArgOffs += delta;
4 changes: 2 additions & 2 deletions src/coreclr/jit/lsra.h
@@ -508,13 +508,13 @@ class RegRecord : public Referenceable
{
registerType = FloatRegisterType;
}
#if defined(TARGET_XARCH) && defined(FEATURE_SIMD)
#if defined(FEATURE_MASKED_HW_INTRINSICS)
else
{
assert(emitter::isMaskReg(reg));
registerType = MaskRegisterType;
}
#endif
#endif // FEATURE_MASKED_HW_INTRINSICS
regNum = reg;
isCalleeSave = ((RBM_CALLEE_SAVED & genRegMask(reg)) != 0);
}
7 changes: 5 additions & 2 deletions src/coreclr/jit/lsrabuild.cpp
@@ -855,6 +855,9 @@ regMaskTP LinearScan::getKillSetForCall(GenTreeCall* call)

#else
killMask.RemoveRegsetForType(RBM_FLT_CALLEE_TRASH.GetFloatRegSet(), FloatRegisterType);
#if defined(TARGET_ARM64)
killMask.RemoveRegsetForType(RBM_MSK_CALLEE_TRASH.GetFloatRegSet(), MaskRegisterType);
#endif // TARGET_ARM64
#endif // TARGET_XARCH
}
#ifdef TARGET_ARM
@@ -1148,8 +1151,8 @@ bool LinearScan::buildKillPositionsForNode(GenTree* tree, LsraLocation currentLo
{
continue;
}
Interval* interval = getIntervalForLocalVar(varIndex);
const bool isCallKill = ((killMask == RBM_INT_CALLEE_TRASH) || (killMask == RBM_CALLEE_TRASH));
Interval* interval = getIntervalForLocalVar(varIndex);
const bool isCallKill = ((killMask.getLow() == RBM_INT_CALLEE_TRASH) || (killMask == RBM_CALLEE_TRASH));
SingleTypeRegSet regsKillMask = killMask.GetRegSetForType(interval->registerType);

if (isCallKill)
2 changes: 1 addition & 1 deletion src/coreclr/jit/regset.cpp
@@ -705,7 +705,7 @@ void RegSet::tmpPreAllocateTemps(var_types type, unsigned count)
}
#endif // TARGET_ARM

TempDsc* temp = new (m_rsCompiler, CMK_Unknown) TempDsc(-((int)tmpCount), size, type);
TempDsc* temp = new (m_rsCompiler, CMK_Unknown) TempDsc(-((int)tmpCount), size, type, i);

#ifdef DEBUG
if (m_rsCompiler->verbose)
17 changes: 8 additions & 9 deletions src/coreclr/jit/targetarm64.h
@@ -75,8 +75,15 @@
#define RBM_FLT_CALLEE_SAVED (RBM_V8|RBM_V9|RBM_V10|RBM_V11|RBM_V12|RBM_V13|RBM_V14|RBM_V15)
#define RBM_FLT_CALLEE_TRASH (RBM_V0|RBM_V1|RBM_V2|RBM_V3|RBM_V4|RBM_V5|RBM_V6|RBM_V7|RBM_V16|RBM_V17|RBM_V18|RBM_V19|RBM_V20|RBM_V21|RBM_V22|RBM_V23|RBM_V24|RBM_V25|RBM_V26|RBM_V27|RBM_V28|RBM_V29|RBM_V30|RBM_V31)

#define RBM_LOWMASK (RBM_P0|RBM_P1|RBM_P2|RBM_P3|RBM_P4|RBM_P5|RBM_P6|RBM_P7)
#define RBM_HIGHMASK (RBM_P8|RBM_P9|RBM_P10| RBM_P11|RBM_P12|RBM_P13|RBM_P14|RBM_P15)
#define RBM_ALLMASK (RBM_LOWMASK|RBM_HIGHMASK)

#define RBM_MSK_CALLEE_SAVED (0)
#define RBM_MSK_CALLEE_TRASH RBM_ALLMASK
Member (author): Somewhere I should just zero it out if we are not running on an SVE machine.

Member: Is the TP cost coming from the additional killed registers? I assume that's because we don't have a predicate-register equivalent of compFloatingPointUsed.

I wonder if you could just add a case for predicate registers here:

compiler->compFloatingPointUsed = true;

And then during allocation, mask out the predicate registers when processing kills if no predicate registers were used. We would still be creating additional RegRecords though, but maybe this helps a bit.

Member: > We would still be creating additional RegRecords though, but maybe this helps a bit.

Actually I guess we were creating those RegRecords even before this PR, so I imagine it would help quite a bit for this PR.

Member (author): I am already doing it in https://github.com/dotnet/runtime/pull/104065/files#diff-ad66a6bcf1fd550d5ad10d995c03218afbbc39463d36e1f2a224f9ca070a2f99R858-R860. Predicate registers exist only in the presence of floating-point usage. Yes, we do process the newly added predicate registers in processKills(), and that's what shows up impacting TP. For a non-SVE arm64 machine, we don't have to iterate through them.

Member (author): > Somewhere I should just zero it out if we are not running on an SVE machine.

Although note that when we altjit, we say "sve capability enabled", so we will see predicate registers and will process them during kills. The TP information will be misleading for those cases, but I will add this anyway so that on a non-SVE arm64 machine, we do not process them.

Member: I think the use of predicate registers is going to be much more rare than using float registers, hence adding this extra check would help regardless.

> I will add this anyway so that on a non-SVE arm64 machine, we do not process them.

I don't see a good reason to try optimizing for non-SVE machines. In the future we would expect most arm64 machines to be SVE enabled, right? I think we should rather optimize for the common case of "predicate registers not used". It should be possible now that we are only creating one RefTypeKill per call.

Member (author): > I don't see a good reason to try optimizing for non-SVE machines. In the future we would expect most arm64 machines to be SVE enabled, right?

Yes.

> I think the use of predicate registers is going to be much more rare than using float registers, hence adding this extra check would help regardless.

Agree. I will do a separate pass for it. #104157 tracks it.

Member (author): Thinking about this a bit more, I think I will just revert bcfd8a8 and do it properly in #104157.


#define RBM_CALLEE_SAVED (RBM_INT_CALLEE_SAVED | RBM_FLT_CALLEE_SAVED)
#define RBM_CALLEE_TRASH (RBM_INT_CALLEE_TRASH | RBM_FLT_CALLEE_TRASH)
#define RBM_CALLEE_TRASH (RBM_INT_CALLEE_TRASH | RBM_FLT_CALLEE_TRASH | RBM_MSK_CALLEE_TRASH)

#define REG_DEFAULT_HELPER_CALL_TARGET REG_R12
#define RBM_DEFAULT_HELPER_CALL_TARGET RBM_R12
@@ -146,14 +153,6 @@
#define REG_JUMP_THUNK_PARAM REG_R12
#define RBM_JUMP_THUNK_PARAM RBM_R12

#define RBM_LOWMASK (RBM_P0 | RBM_P1 | RBM_P2 | RBM_P3 | RBM_P4 | RBM_P5 | RBM_P6 | RBM_P7)
#define RBM_HIGHMASK (RBM_P8 | RBM_P9 | RBM_P10 | RBM_P11 | RBM_P12 | RBM_P13 | RBM_P14 | RBM_P15)
#define RBM_ALLMASK (RBM_LOWMASK | RBM_HIGHMASK)

// TODO-SVE: Fix when adding predicate register allocation
#define RBM_MSK_CALLEE_SAVED (0)
#define RBM_MSK_CALLEE_TRASH (0)

// ARM64 write barrier ABI (see vm\arm64\asmhelpers.asm, vm\arm64\asmhelpers.S):
// CORINFO_HELP_ASSIGN_REF (JIT_WriteBarrier), CORINFO_HELP_CHECKED_ASSIGN_REF (JIT_CheckedWriteBarrier):
// On entry: