Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 8a0c9a3

Browse files
committed Jul 25, 2023
Respond to PR feedback and try to reduce TP regression more
1 parent 136e898 commit 8a0c9a3

File tree

10 files changed

+200
-24
lines changed

10 files changed

+200
-24
lines changed
 

‎src/coreclr/jit/compiler.cpp

+11-1
Original file line numberDiff line numberDiff line change
@@ -3386,9 +3386,19 @@ void Compiler::compInitOptions(JitFlags* jitFlags)
33863386
{
33873387
rbmAllMask |= RBM_ALLMASK_EVEX;
33883388
rbmMskCalleeTrash |= RBM_MSK_CALLEE_TRASH_EVEX;
3389-
cntCalleeTrashMask += CNT_CALLEE_TRASH_MASK;
3389+
cntCalleeTrashMask += CNT_CALLEE_TRASH_MASK_EVEX;
33903390
}
33913391

3392+
// Make sure we copy the register info and initialize the
3393+
// trash regs after the underlying fields are initialized
3394+
3395+
const regMaskTP vtCalleeTrashRegs[TYP_COUNT]{
3396+
#define DEF_TP(tn, nm, jitType, sz, sze, asze, st, al, regTyp, regFld, csr, ctr, tf) ctr,
3397+
#include "typelist.h"
3398+
#undef DEF_TP
3399+
};
3400+
memcpy(varTypeCalleeTrashRegs, vtCalleeTrashRegs, sizeof(regMaskTP) * TYP_COUNT);
3401+
33923402
codeGen->CopyRegisterInfo();
33933403
#endif // TARGET_XARCH
33943404
}

‎src/coreclr/jit/compiler.h

+7-6
Original file line numberDiff line numberDiff line change
@@ -10899,15 +10899,15 @@ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
1089910899
unsigned cntCalleeTrashFloat;
1090010900

1090110901
public:
10902-
regMaskTP get_RBM_ALLFLOAT() const
10902+
FORCEINLINE regMaskTP get_RBM_ALLFLOAT() const
1090310903
{
1090410904
return this->rbmAllFloat;
1090510905
}
10906-
regMaskTP get_RBM_FLT_CALLEE_TRASH() const
10906+
FORCEINLINE regMaskTP get_RBM_FLT_CALLEE_TRASH() const
1090710907
{
1090810908
return this->rbmFltCalleeTrash;
1090910909
}
10910-
unsigned get_CNT_CALLEE_TRASH_FLOAT() const
10910+
FORCEINLINE unsigned get_CNT_CALLEE_TRASH_FLOAT() const
1091110911
{
1091210912
return this->cntCalleeTrashFloat;
1091310913
}
@@ -10935,17 +10935,18 @@ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
1093510935
regMaskTP rbmAllMask;
1093610936
regMaskTP rbmMskCalleeTrash;
1093710937
unsigned cntCalleeTrashMask;
10938+
regMaskTP varTypeCalleeTrashRegs[TYP_COUNT];
1093810939

1093910940
public:
10940-
regMaskTP get_RBM_ALLMASK() const
10941+
FORCEINLINE regMaskTP get_RBM_ALLMASK() const
1094110942
{
1094210943
return this->rbmAllMask;
1094310944
}
10944-
regMaskTP get_RBM_MSK_CALLEE_TRASH() const
10945+
FORCEINLINE regMaskTP get_RBM_MSK_CALLEE_TRASH() const
1094510946
{
1094610947
return this->rbmMskCalleeTrash;
1094710948
}
10948-
unsigned get_CNT_CALLEE_TRASH_MASK() const
10949+
FORCEINLINE unsigned get_CNT_CALLEE_TRASH_MASK() const
1094910950
{
1095010951
return this->cntCalleeTrashMask;
1095110952
}

‎src/coreclr/jit/hwintrinsiccodegenxarch.cpp

+111
Original file line numberDiff line numberDiff line change
@@ -2136,6 +2136,39 @@ void CodeGen::genAvxFamilyIntrinsic(GenTreeHWIntrinsic* node)
21362136
break;
21372137
}
21382138

2139+
case NI_AVX512F_NotMask:
2140+
{
2141+
uint32_t simdSize = node->GetSimdSize();
2142+
uint32_t count = simdSize / genTypeSize(baseType);
2143+
2144+
if (count <= 8)
2145+
{
2146+
assert((count == 2) || (count == 4) || (count == 8));
2147+
ins = INS_knotb;
2148+
}
2149+
else if (count == 16)
2150+
{
2151+
ins = INS_knotw;
2152+
}
2153+
else if (count == 32)
2154+
{
2155+
ins = INS_knotd;
2156+
}
2157+
else
2158+
{
2159+
assert(count == 64);
2160+
ins = INS_knotq;
2161+
}
2162+
2163+
op1Reg = op1->GetRegNum();
2164+
2165+
assert(emitter::isMaskReg(targetReg));
2166+
assert(emitter::isMaskReg(op1Reg));
2167+
2168+
emit->emitIns_R_R(ins, EA_8BYTE, targetReg, op1Reg);
2169+
break;
2170+
}
2171+
21392172
case NI_AVX512F_OrMask:
21402173
{
21412174
uint32_t simdSize = node->GetSimdSize();
@@ -2174,6 +2207,84 @@ void CodeGen::genAvxFamilyIntrinsic(GenTreeHWIntrinsic* node)
21742207
break;
21752208
}
21762209

2210+
case NI_AVX512F_ShiftLeftMask:
2211+
{
2212+
uint32_t simdSize = node->GetSimdSize();
2213+
uint32_t count = simdSize / genTypeSize(baseType);
2214+
2215+
if (count <= 8)
2216+
{
2217+
assert((count == 2) || (count == 4) || (count == 8));
2218+
ins = INS_kshiftlb;
2219+
}
2220+
else if (count == 16)
2221+
{
2222+
ins = INS_kshiftlw;
2223+
}
2224+
else if (count == 32)
2225+
{
2226+
ins = INS_kshiftld;
2227+
}
2228+
else
2229+
{
2230+
assert(count == 64);
2231+
ins = INS_kshiftlq;
2232+
}
2233+
2234+
op1Reg = op1->GetRegNum();
2235+
2236+
GenTree* op2 = node->Op(2);
2237+
assert(op2->IsCnsIntOrI() && op2->isContained());
2238+
2239+
assert(emitter::isMaskReg(targetReg));
2240+
assert(emitter::isMaskReg(op1Reg));
2241+
2242+
ssize_t ival = op2->AsIntCon()->IconValue();
2243+
assert((ival >= 0) && (ival <= 255));
2244+
2245+
emit->emitIns_R_R_I(ins, EA_8BYTE, targetReg, op1Reg, (int8_t)ival);
2246+
break;
2247+
}
2248+
2249+
case NI_AVX512F_ShiftRightMask:
2250+
{
2251+
uint32_t simdSize = node->GetSimdSize();
2252+
uint32_t count = simdSize / genTypeSize(baseType);
2253+
2254+
if (count <= 8)
2255+
{
2256+
assert((count == 2) || (count == 4) || (count == 8));
2257+
ins = INS_kshiftrb;
2258+
}
2259+
else if (count == 16)
2260+
{
2261+
ins = INS_kshiftrw;
2262+
}
2263+
else if (count == 32)
2264+
{
2265+
ins = INS_kshiftrd;
2266+
}
2267+
else
2268+
{
2269+
assert(count == 64);
2270+
ins = INS_kshiftrq;
2271+
}
2272+
2273+
op1Reg = op1->GetRegNum();
2274+
2275+
GenTree* op2 = node->Op(2);
2276+
assert(op2->IsCnsIntOrI() && op2->isContained());
2277+
2278+
assert(emitter::isMaskReg(targetReg));
2279+
assert(emitter::isMaskReg(op1Reg));
2280+
2281+
ssize_t ival = op2->AsIntCon()->IconValue();
2282+
assert((ival >= 0) && (ival <= 255));
2283+
2284+
emit->emitIns_R_R_I(ins, EA_8BYTE, targetReg, op1Reg, (int8_t)ival);
2285+
break;
2286+
}
2287+
21772288
case NI_AVX512F_XorMask:
21782289
{
21792290
uint32_t simdSize = node->GetSimdSize();

‎src/coreclr/jit/hwintrinsiclistxarch.h

+2
Original file line numberDiff line numberDiff line change
@@ -1334,6 +1334,8 @@ HARDWARE_INTRINSIC(AVX512F, NotMask,
13341334
HARDWARE_INTRINSIC(AVX512F, op_EqualityMask, -1, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Special, HW_Flag_NoContainment|HW_Flag_Commutative)
13351335
HARDWARE_INTRINSIC(AVX512F, op_InequalityMask, -1, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Special, HW_Flag_NoContainment|HW_Flag_Commutative)
13361336
HARDWARE_INTRINSIC(AVX512F, OrMask, -1, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Special, HW_Flag_NoContainment|HW_Flag_Commutative|HW_Flag_ReturnsPerElementMask)
1337+
HARDWARE_INTRINSIC(AVX512F, ShiftLeftMask, -1, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_IMM, HW_Flag_FullRangeIMM|HW_Flag_SpecialCodeGen)
1338+
HARDWARE_INTRINSIC(AVX512F, ShiftRightMask, -1, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_IMM, HW_Flag_FullRangeIMM|HW_Flag_SpecialCodeGen)
13371339
HARDWARE_INTRINSIC(AVX512F, XorMask, -1, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Special, HW_Flag_NoContainment|HW_Flag_Commutative|HW_Flag_ReturnsPerElementMask)
13381340

13391341
#endif // FEATURE_HW_INTRINSIC

‎src/coreclr/jit/lowerxarch.cpp

+29-2
Original file line numberDiff line numberDiff line change
@@ -2002,10 +2002,37 @@ GenTree* Lowering::LowerHWIntrinsicCmpOp(GenTreeHWIntrinsic* node, genTreeOps cm
20022002

20032003
default:
20042004
{
2005-
maskIntrinsicId = NI_AVX512F_NotMask;
2006-
maskNode = comp->gtNewSimdHWIntrinsicNode(TYP_MASK, maskNode, maskIntrinsicId,
2005+
// We don't have a well-known intrinsic, so we need to invert the mask while keeping the upper
2006+
// n-bits clear. If we have 1 element, then the upper 7-bits need to be cleared. If we have
2007+
// 2, then the upper 6-bits, and if we have 4, then the upper 4-bits.
2008+
//
2009+
// There isn't necessarily a trivial way to do this other than not, shift-left by n,
2010+
// shift-right by n. This preserves count bits, while clearing the upper n-bits
2011+
2012+
GenTree* cnsNode;
2013+
2014+
maskNode = comp->gtNewSimdHWIntrinsicNode(TYP_MASK, maskNode, NI_AVX512F_NotMask,
20072015
simdBaseJitType, simdSize);
20082016
BlockRange().InsertBefore(node, maskNode);
2017+
2018+
cnsNode = comp->gtNewIconNode(8 - count);
2019+
BlockRange().InsertAfter(maskNode, cnsNode);
2020+
2021+
maskNode =
2022+
comp->gtNewSimdHWIntrinsicNode(TYP_MASK, maskNode, cnsNode, NI_AVX512F_ShiftLeftMask,
2023+
simdBaseJitType, simdSize);
2024+
BlockRange().InsertAfter(cnsNode, maskNode);
2025+
LowerNode(maskNode);
2026+
2027+
cnsNode = comp->gtNewIconNode(8 - count);
2028+
BlockRange().InsertAfter(maskNode, cnsNode);
2029+
2030+
maskNode =
2031+
comp->gtNewSimdHWIntrinsicNode(TYP_MASK, maskNode, cnsNode, NI_AVX512F_ShiftRightMask,
2032+
simdBaseJitType, simdSize);
2033+
BlockRange().InsertAfter(cnsNode, maskNode);
2034+
2035+
maskIntrinsicId = NI_AVX512F_ShiftRightMask;
20092036
break;
20102037
}
20112038
}

‎src/coreclr/jit/lsra.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -716,11 +716,11 @@ LinearScan::LinearScan(Compiler* theCompiler)
716716
#if defined(TARGET_XARCH)
717717
rbmAllMask = compiler->rbmAllMask;
718718
rbmMskCalleeTrash = compiler->rbmMskCalleeTrash;
719+
memcpy(varTypeCalleeTrashRegs, compiler->varTypeCalleeTrashRegs, sizeof(regMaskTP) * TYP_COUNT);
719720

720721
if (!compiler->canUseEvexEncoding())
721722
{
722-
availableRegCount -= CNT_HIGHFLOAT;
723-
availableRegCount -= CNT_MASK_REGS;
723+
availableRegCount -= (CNT_HIGHFLOAT + CNT_MASK_REGS);
724724
}
725725
#endif // TARGET_XARCH
726726

‎src/coreclr/jit/lsra.h

+16-7
Original file line numberDiff line numberDiff line change
@@ -2027,11 +2027,11 @@ class LinearScan : public LinearScanInterface
20272027
regMaskTP rbmAllFloat;
20282028
regMaskTP rbmFltCalleeTrash;
20292029

2030-
regMaskTP get_RBM_ALLFLOAT() const
2030+
FORCEINLINE regMaskTP get_RBM_ALLFLOAT() const
20312031
{
20322032
return this->rbmAllFloat;
20332033
}
2034-
regMaskTP get_RBM_FLT_CALLEE_TRASH() const
2034+
FORCEINLINE regMaskTP get_RBM_FLT_CALLEE_TRASH() const
20352035
{
20362036
return this->rbmFltCalleeTrash;
20372037
}
@@ -2041,19 +2041,19 @@ class LinearScan : public LinearScanInterface
20412041
regMaskTP rbmAllMask;
20422042
regMaskTP rbmMskCalleeTrash;
20432043

2044-
regMaskTP get_RBM_ALLMASK() const
2044+
FORCEINLINE regMaskTP get_RBM_ALLMASK() const
20452045
{
20462046
return this->rbmAllMask;
20472047
}
2048-
regMaskTP get_RBM_MSK_CALLEE_TRASH() const
2048+
FORCEINLINE regMaskTP get_RBM_MSK_CALLEE_TRASH() const
20492049
{
20502050
return this->rbmMskCalleeTrash;
20512051
}
20522052
#endif // TARGET_XARCH
20532053

20542054
unsigned availableRegCount;
20552055

2056-
unsigned get_AVAILABLE_REG_COUNT() const
2056+
FORCEINLINE unsigned get_AVAILABLE_REG_COUNT() const
20572057
{
20582058
return this->availableRegCount;
20592059
}
@@ -2064,7 +2064,7 @@ class LinearScan : public LinearScanInterface
20642064
// NOTE: we currently don't need a LinearScan `this` pointer for this definition, and some callers
20652065
// don't have one available, so make it static.
20662066
//
2067-
static regMaskTP calleeSaveRegs(RegisterType rt)
2067+
static FORCEINLINE regMaskTP calleeSaveRegs(RegisterType rt)
20682068
{
20692069
static const regMaskTP varTypeCalleeSaveRegs[] = {
20702070
#define DEF_TP(tn, nm, jitType, sz, sze, asze, st, al, regTyp, regFld, csr, ctr, tf) csr,
@@ -2076,16 +2076,25 @@ class LinearScan : public LinearScanInterface
20762076
return varTypeCalleeSaveRegs[rt];
20772077
}
20782078

2079+
#if defined(TARGET_XARCH)
2080+
// Not all of the callee trash values are constant, so don't declare this as a method-local static;
2081+
// doing so results in significantly more complex codegen and we'd rather just initialize this once
2082+
// as part of initializing LSRA instead
2083+
regMaskTP varTypeCalleeTrashRegs[TYP_COUNT];
2084+
#endif // TARGET_XARCH
2085+
20792086
//------------------------------------------------------------------------
20802087
// callerSaveRegs: Get the set of caller-save registers of the given RegisterType
20812088
//
2082-
regMaskTP callerSaveRegs(RegisterType rt) const
2089+
FORCEINLINE regMaskTP callerSaveRegs(RegisterType rt) const
20832090
{
2091+
#if !defined(TARGET_XARCH)
20842092
static const regMaskTP varTypeCalleeTrashRegs[] = {
20852093
#define DEF_TP(tn, nm, jitType, sz, sze, asze, st, al, regTyp, regFld, csr, ctr, tf) ctr,
20862094
#include "typelist.h"
20872095
#undef DEF_TP
20882096
};
2097+
#endif // !TARGET_XARCH
20892098

20902099
assert((unsigned)rt < ArrLen(varTypeCalleeTrashRegs));
20912100
return varTypeCalleeTrashRegs[rt];

‎src/coreclr/jit/lsrabuild.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -891,10 +891,10 @@ regMaskTP LinearScan::getKillSetForCall(GenTreeCall* call)
891891
// if there is no FP used, we can ignore the FP kills
892892
if (!compiler->compFloatingPointUsed)
893893
{
894-
killMask &= ~RBM_FLT_CALLEE_TRASH;
895-
896894
#if defined(TARGET_XARCH)
897-
killMask &= ~RBM_MSK_CALLEE_TRASH;
895+
killMask &= ~(RBM_FLT_CALLEE_TRASH | RBM_MSK_CALLEE_TRASH);
896+
#else
897+
killMask &= ~RBM_FLT_CALLEE_TRASH;
898898
#endif // TARGET_XARCH
899899
}
900900
#ifdef TARGET_ARM

‎src/coreclr/jit/vartype.h

+13-1
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,19 @@ inline bool varTypeUsesFloatReg(T vt)
328328
template <class T>
329329
inline bool varTypeUsesMaskReg(T vt)
330330
{
331-
return varTypeRegister[TypeGet(vt)] == VTR_MASK;
331+
// The technically correct check is:
332+
// return varTypeRegister[TypeGet(vt)] == VTR_MASK;
333+
//
334+
// However, we only have one type that uses VTR_MASK today
335+
// and so it's quite a bit cheaper to just check that directly
336+
337+
#if defined(FEATURE_SIMD) && defined(TARGET_XARCH)
338+
assert((TypeGet(vt) == TYP_MASK) || (varTypeRegister[TypeGet(vt)] != VTR_MASK));
339+
return TypeGet(vt) == TYP_MASK;
340+
#else
341+
assert(varTypeRegister[TypeGet(vt)] != VTR_MASK);
342+
return false;
343+
#endif
332344
}
333345

334346
template <class T>

‎src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Char.cs

+6-2
Original file line numberDiff line numberDiff line change
@@ -647,8 +647,9 @@ public static unsafe int IndexOfNullCharacter(char* searchSpace)
647647

648648
Vector512<ushort> search = *(Vector512<ushort>*)(searchSpace + (nuint)offset);
649649

650-
// Note that MoveMask has converted the equal vector elements into a set of bit flags,
651-
// So the bit position in 'matches' corresponds to the element offset.
650+
// AVX-512 returns comparison results in a mask register, so we want to optimize
651+
// the core check to simply be a "none match" check. This will slightly increase
652+
// the cost for the early match case, but greatly improves perf otherwise.
652653
if (!Vector512.EqualsAny(search, Vector512<ushort>.Zero))
653654
{
654655
// Zero flags set so no matches
@@ -657,6 +658,9 @@ public static unsafe int IndexOfNullCharacter(char* searchSpace)
657658
continue;
658659
}
659660

661+
// Note that ExtractMostSignificantBits has converted the equal vector elements into a set of bit flags,
662+
// so the bit position in 'matches' corresponds to the element offset.
663+
//
660664
// Find bitflag offset of first match and add to current offset,
661665
// flags are in bytes so divide for chars
662666
ulong matches = Vector512.Equals(search, Vector512<ushort>.Zero).ExtractMostSignificantBits();

0 commit comments

Comments
 (0)
Please sign in to comment.