Skip to content

Commit

Permalink
Arm64/SVE: Add support to handle predicate registers as callee-trash (#104065)
Browse files Browse the repository at this point in the history

* Add MSK_CALLEE_TRASH and include it in CALLEE_TRASH

* Assign correct registerType for predicate registers

* Handle the save/restore of predicate registers

* misc changes

* jit format

* Remove handling of Temps and use the same as locals

* use GetPredicateRegSet()

* Disable mask registers if on non-sve

* small change in DbgEnc

* jit format

* Revert "jit format"

This reverts commit 5535c69.

* Revert "small change in DbgEnc"

This reverts commit bb97d80.

* Revert "Disable mask registers if on non-sve"

This reverts commit bcfd8a8.

* minor review feedback
  • Loading branch information
kunalspathak authored Jun 28, 2024
1 parent e24ea1d commit bcffebf
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 16 deletions.
36 changes: 34 additions & 2 deletions src/coreclr/jit/emitarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7884,7 +7884,22 @@ void emitter::emitIns_R_S(instruction ins, emitAttr attr, regNumber reg1, int va
isSimple = false;
size = EA_SCALABLE;
attr = size;
fmt = isVectorRegister(reg1) ? IF_SVE_IE_2A : IF_SVE_ID_2A;
if (isPredicateRegister(reg1))
{
assert(offs == 0);
// For predicate, generate based off rsGetRsvdReg()
regNumber rsvdReg = codeGen->rsGetRsvdReg();

// add rsvd, fp, #imm
emitIns_R_R_I(INS_add, EA_8BYTE, rsvdReg, reg2, imm);
// <ins> p0, [rsvd, #0, mul vl]
emitIns_R_R_I(ins, attr, reg1, rsvdReg, 0);

return;
}

assert(isVectorRegister(reg1));
fmt = IF_SVE_IE_2A;

// TODO-SVE: Don't assume 128bit vectors
// Predicate size is vector length / 8
Expand Down Expand Up @@ -8138,7 +8153,24 @@ void emitter::emitIns_S_R(instruction ins, emitAttr attr, regNumber reg1, int va
isSimple = false;
size = EA_SCALABLE;
attr = size;
fmt = isVectorRegister(reg1) ? IF_SVE_JH_2A : IF_SVE_JG_2A;

if (isPredicateRegister(reg1))
{
assert(offs == 0);

// For predicate, generate based off rsGetRsvdReg()
regNumber rsvdReg = codeGen->rsGetRsvdReg();

// add rsvd, fp, #imm
emitIns_R_R_I(INS_add, EA_8BYTE, rsvdReg, reg2, imm);
// str p0, [rsvd, #0, mul vl]
emitIns_R_R_I(ins, attr, reg1, rsvdReg, 0);

return;
}

assert(isVectorRegister(reg1));
fmt = IF_SVE_JH_2A;

// TODO-SVE: Don't assume 128bit vectors
// Predicate size is vector length / 8
Expand Down
5 changes: 5 additions & 0 deletions src/coreclr/jit/emitarm64.h
Original file line number Diff line number Diff line change
Expand Up @@ -1205,6 +1205,11 @@ inline static bool isHighPredicateRegister(regNumber reg)
return (reg >= REG_PREDICATE_HIGH_FIRST) && (reg <= REG_PREDICATE_HIGH_LAST);
}

// Reports whether `reg` is a mask register. This forwards directly to
// isPredicateRegister(reg), so the two queries always give the same answer.
inline static bool isMaskReg(regNumber reg)
{
    const bool isPredicate = isPredicateRegister(reg);
    return isPredicate;
}

inline static bool isEvenRegister(regNumber reg)
{
if (isGeneralRegister(reg))
Expand Down
2 changes: 1 addition & 1 deletion src/coreclr/jit/lclvars.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5593,7 +5593,7 @@ unsigned Compiler::lvaGetMaxSpillTempSize()
* Doing this all in one pass is 'hard'. So instead we do it in 2 basic passes:
* 1. Assign all the offsets relative to the Virtual '0'. Offsets above (the
* incoming arguments) are positive. Offsets below (everything else) are
* negative. This pass also calcuates the total frame size (between Caller's
* negative. This pass also calculates the total frame size (between Caller's
* SP/return address and the Ambient SP).
* 2. Figure out where to place the frame pointer, and then adjust the offsets
* as needed for the final stack size and whether the offset is frame pointer
Expand Down
4 changes: 2 additions & 2 deletions src/coreclr/jit/lsra.h
Original file line number Diff line number Diff line change
Expand Up @@ -508,13 +508,13 @@ class RegRecord : public Referenceable
{
registerType = FloatRegisterType;
}
#if defined(TARGET_XARCH) && defined(FEATURE_SIMD)
#if defined(FEATURE_MASKED_HW_INTRINSICS)
else
{
assert(emitter::isMaskReg(reg));
registerType = MaskRegisterType;
}
#endif
#endif // FEATURE_MASKED_HW_INTRINSICS
regNum = reg;
isCalleeSave = ((RBM_CALLEE_SAVED & genRegMask(reg)) != 0);
}
Expand Down
7 changes: 5 additions & 2 deletions src/coreclr/jit/lsrabuild.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,9 @@ regMaskTP LinearScan::getKillSetForCall(GenTreeCall* call)

#else
killMask.RemoveRegsetForType(RBM_FLT_CALLEE_TRASH.GetFloatRegSet(), FloatRegisterType);
#if defined(TARGET_ARM64)
killMask.RemoveRegsetForType(RBM_MSK_CALLEE_TRASH.GetPredicateRegSet(), MaskRegisterType);
#endif // TARGET_ARM64
#endif // TARGET_XARCH
}
#ifdef TARGET_ARM
Expand Down Expand Up @@ -1148,8 +1151,8 @@ bool LinearScan::buildKillPositionsForNode(GenTree* tree, LsraLocation currentLo
{
continue;
}
Interval* interval = getIntervalForLocalVar(varIndex);
const bool isCallKill = ((killMask == RBM_INT_CALLEE_TRASH) || (killMask == RBM_CALLEE_TRASH));
Interval* interval = getIntervalForLocalVar(varIndex);
const bool isCallKill = ((killMask.getLow() == RBM_INT_CALLEE_TRASH) || (killMask == RBM_CALLEE_TRASH));
SingleTypeRegSet regsKillMask = killMask.GetRegSetForType(interval->registerType);

if (isCallKill)
Expand Down
17 changes: 8 additions & 9 deletions src/coreclr/jit/targetarm64.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,15 @@
#define RBM_FLT_CALLEE_SAVED (RBM_V8|RBM_V9|RBM_V10|RBM_V11|RBM_V12|RBM_V13|RBM_V14|RBM_V15)
#define RBM_FLT_CALLEE_TRASH (RBM_V0|RBM_V1|RBM_V2|RBM_V3|RBM_V4|RBM_V5|RBM_V6|RBM_V7|RBM_V16|RBM_V17|RBM_V18|RBM_V19|RBM_V20|RBM_V21|RBM_V22|RBM_V23|RBM_V24|RBM_V25|RBM_V26|RBM_V27|RBM_V28|RBM_V29|RBM_V30|RBM_V31)

// Predicate (mask) register sets: P0-P7 (low), P8-P15 (high), and the union of both.
#define RBM_LOWMASK (RBM_P0|RBM_P1|RBM_P2|RBM_P3|RBM_P4|RBM_P5|RBM_P6|RBM_P7)
#define RBM_HIGHMASK (RBM_P8|RBM_P9|RBM_P10| RBM_P11|RBM_P12|RBM_P13|RBM_P14|RBM_P15)
#define RBM_ALLMASK (RBM_LOWMASK|RBM_HIGHMASK)

// No predicate register is in the callee-saved set; every predicate register
// (RBM_ALLMASK) is in the callee-trash set.
#define RBM_MSK_CALLEE_SAVED (0)
#define RBM_MSK_CALLEE_TRASH RBM_ALLMASK

#define RBM_CALLEE_SAVED (RBM_INT_CALLEE_SAVED | RBM_FLT_CALLEE_SAVED)
#define RBM_CALLEE_TRASH (RBM_INT_CALLEE_TRASH | RBM_FLT_CALLEE_TRASH)
#define RBM_CALLEE_TRASH (RBM_INT_CALLEE_TRASH | RBM_FLT_CALLEE_TRASH | RBM_MSK_CALLEE_TRASH)

#define REG_DEFAULT_HELPER_CALL_TARGET REG_R12
#define RBM_DEFAULT_HELPER_CALL_TARGET RBM_R12
Expand Down Expand Up @@ -146,14 +153,6 @@
#define REG_JUMP_THUNK_PARAM REG_R12
#define RBM_JUMP_THUNK_PARAM RBM_R12

#define RBM_LOWMASK (RBM_P0 | RBM_P1 | RBM_P2 | RBM_P3 | RBM_P4 | RBM_P5 | RBM_P6 | RBM_P7)
#define RBM_HIGHMASK (RBM_P8 | RBM_P9 | RBM_P10 | RBM_P11 | RBM_P12 | RBM_P13 | RBM_P14 | RBM_P15)
#define RBM_ALLMASK (RBM_LOWMASK | RBM_HIGHMASK)

// TODO-SVE: Fix when adding predicate register allocation
#define RBM_MSK_CALLEE_SAVED (0)
#define RBM_MSK_CALLEE_TRASH (0)

// ARM64 write barrier ABI (see vm\arm64\asmhelpers.asm, vm\arm64\asmhelpers.S):
// CORINFO_HELP_ASSIGN_REF (JIT_WriteBarrier), CORINFO_HELP_CHECKED_ASSIGN_REF (JIT_CheckedWriteBarrier):
// On entry:
Expand Down

0 comments on commit bcffebf

Please sign in to comment.