Skip to content

Commit dd24e9c

Browse files
kunalspathak and a74nh authored
Arm64/Sve: Predicated Abs, Predicated/UnPredicated Add, Conditional Select (#100743)
* JIT ARM64-SVE: Add Sve.Abs() and Sve.Add() Change-Id: Ie8cfe828595da9a87adbc0857c0c44c0ce12f5b2 * Fix sve scaling in enitIns_R_S/S_R * Revert "Fix sve scaling in enitIns_R_S/S_R" This reverts commit e9fa735. * Fix sve scaling in enitIns_R_S/S_R * Restore testing * Use NaturalScale_helper for vector load/stores * wip * Add ConditionalSelect() APIs * Handle ConditionalSelect in JIT * Add test coverage * Update the test cases * jit format * fix merge conflicts * Make predicated/unpredicated work with ConditionalSelect Still some handling around RMW is needed, but this basically works * Misc. changes * jit format * jit format * Handle all the conditions correctly * jit format * fix some spacing * Removed the assert * fix the largest vector size to 64 to fix #100366 * review feedback * wip * Add SVE feature detection for Windows * fix the check for invalid alignment * Revert "Add SVE feature detection for Windows" This reverts commit ed7c781. * Handle case where Abs() is wrapped in another conditionalSelect * jit format * fix the size comparison * HW_Flag_MaskedPredicatedOnlyOperation * Revert the change in emitarm64.cpp around INS_sve_ldr_mask/INS_sve_str_mask * Fix the condition for lowering * address review feedback for movprfx * Move the special handling of Vector<>.Zero from lowerer to importer * Rename IsEmbeddedMaskedOperation/IsOptionalEmbeddedMaskedOperation * Add more test coverage for conditionalSelect * Rename test method name * Add more test coverage for conditionalSelect:Abs * jit format * Add logging on test methods * Add the missing movprfx for abs * Add few more scenarios where falseVal is zero * Make sure LoadVector is marked as explicit needing mask * revisit the codegen logic * Remove commented code and add some other comments * jit format --------- Co-authored-by: Alan Hayward <alan.hayward@arm.com>
1 parent 1a93a9d commit dd24e9c

24 files changed

+1994
-74
lines changed

src/coreclr/jit/codegenlinear.cpp

-2
Original file line numberDiff line numberDiff line change
@@ -1648,15 +1648,13 @@ void CodeGen::genConsumeRegs(GenTree* tree)
16481648
// Update the life of the lcl var.
16491649
genUpdateLife(tree);
16501650
}
1651-
#ifdef TARGET_XARCH
16521651
#ifdef FEATURE_HW_INTRINSICS
16531652
else if (tree->OperIs(GT_HWINTRINSIC))
16541653
{
16551654
GenTreeHWIntrinsic* hwintrinsic = tree->AsHWIntrinsic();
16561655
genConsumeMultiOpOperands(hwintrinsic);
16571656
}
16581657
#endif // FEATURE_HW_INTRINSICS
1659-
#endif // TARGET_XARCH
16601658
else if (tree->OperIs(GT_BITCAST, GT_NEG, GT_CAST, GT_LSH, GT_RSH, GT_RSZ, GT_ROR, GT_BSWAP, GT_BSWAP16))
16611659
{
16621660
genConsumeRegs(tree->gtGetOp1());

src/coreclr/jit/compiler.h

+1
Original file line numberDiff line numberDiff line change
@@ -3477,6 +3477,7 @@ class Compiler
34773477
#if defined(TARGET_ARM64)
34783478
GenTree* gtNewSimdConvertVectorToMaskNode(var_types type, GenTree* node, CorInfoType simdBaseJitType, unsigned simdSize);
34793479
GenTree* gtNewSimdConvertMaskToVectorNode(GenTreeHWIntrinsic* node, var_types type);
3480+
GenTree* gtNewSimdAllTrueMaskNode(CorInfoType simdBaseJitType, unsigned simdSize);
34803481
#endif
34813482

34823483
//------------------------------------------------------------------------

src/coreclr/jit/emitloongarch64.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ enum EmitCallType
333333

334334
EC_FUNC_TOKEN, // Direct call to a helper/static/nonvirtual/global method
335335
// EC_FUNC_TOKEN_INDIR, // Indirect call to a helper/static/nonvirtual/global method
336-
// EC_FUNC_ADDR, // Direct call to an absolute address
336+
// EC_FUNC_ADDR, // Direct call to an absolute address
337337

338338
EC_INDIR_R, // Indirect call via register
339339

src/coreclr/jit/emitriscv64.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ enum EmitCallType
310310

311311
EC_FUNC_TOKEN, // Direct call to a helper/static/nonvirtual/global method
312312
// EC_FUNC_TOKEN_INDIR, // Indirect call to a helper/static/nonvirtual/global method
313-
// EC_FUNC_ADDR, // Direct call to an absolute address
313+
// EC_FUNC_ADDR, // Direct call to an absolute address
314314

315315
// EC_FUNC_VIRTUAL, // Call to a virtual method (using the vtable)
316316
EC_INDIR_R, // Indirect call via register

src/coreclr/jit/gentree.cpp

+8-6
Original file line numberDiff line numberDiff line change
@@ -18012,7 +18012,7 @@ bool GenTree::canBeContained() const
1801218012
}
1801318013
else if (OperIsHWIntrinsic() && !isContainableHWIntrinsic())
1801418014
{
18015-
return isEvexEmbeddedMaskingCompatibleHWIntrinsic();
18015+
return isEmbeddedMaskingCompatibleHWIntrinsic();
1801618016
}
1801718017

1801818018
return true;
@@ -19909,24 +19909,26 @@ bool GenTree::isEvexCompatibleHWIntrinsic() const
1990919909
}
1991019910

1991119911
//------------------------------------------------------------------------
19912-
// isEvexEmbeddedMaskingCompatibleHWIntrinsic: Checks if the intrinsic is compatible
19912+
// isEmbeddedMaskingCompatibleHWIntrinsic: Checks if the intrinsic is compatible
1991319913
// with the embedded masking form for its intended lowering instruction.
1991419914
//
1991519915
// Return Value:
1991619916
// true if the intrinsic node's lowering instruction supports an embedded masking form
1991719917
//
19918-
bool GenTree::isEvexEmbeddedMaskingCompatibleHWIntrinsic() const
19918+
bool GenTree::isEmbeddedMaskingCompatibleHWIntrinsic() const
1991919919
{
19920-
#if defined(TARGET_XARCH)
1992119920
if (OperIsHWIntrinsic())
1992219921
{
19922+
#if defined(TARGET_XARCH)
1992319923
// TODO-AVX512F-CQ: Expand this to the full set of APIs and make it table driven
1992419924
// using IsEmbMaskingCompatible. For now, however, limit it to some explicit ids
1992519925
// for prototyping purposes.
1992619926
return (AsHWIntrinsic()->GetHWIntrinsicId() == NI_AVX512F_Add);
19927+
#elif defined(TARGET_ARM64)
19928+
return HWIntrinsicInfo::IsEmbeddedMaskedOperation(AsHWIntrinsic()->GetHWIntrinsicId()) ||
19929+
HWIntrinsicInfo::IsOptionalEmbeddedMaskedOperation(AsHWIntrinsic()->GetHWIntrinsicId());
19930+
#endif
1992719931
}
19928-
#endif // TARGET_XARCH
19929-
1993019932
return false;
1993119933
}
1993219934

src/coreclr/jit/gentree.h

+6-6
Original file line numberDiff line numberDiff line change
@@ -557,9 +557,9 @@ enum GenTreeFlags : unsigned int
557557

558558
GTF_MDARRLOWERBOUND_NONFAULTING = 0x20000000, // GT_MDARR_LOWER_BOUND -- An MD array lower bound operation that cannot fault. Same as GT_IND_NONFAULTING.
559559

560-
#if defined(TARGET_XARCH) && defined(FEATURE_HW_INTRINSICS)
560+
#ifdef FEATURE_HW_INTRINSICS
561561
GTF_HW_EM_OP = 0x10000000, // GT_HWINTRINSIC -- node is used as an operand to an embedded mask
562-
#endif // TARGET_XARCH && FEATURE_HW_INTRINSICS
562+
#endif // FEATURE_HW_INTRINSICS
563563
};
564564

565565
inline constexpr GenTreeFlags operator ~(GenTreeFlags a)
@@ -1465,7 +1465,7 @@ struct GenTree
14651465
bool isContainableHWIntrinsic() const;
14661466
bool isRMWHWIntrinsic(Compiler* comp);
14671467
bool isEvexCompatibleHWIntrinsic() const;
1468-
bool isEvexEmbeddedMaskingCompatibleHWIntrinsic() const;
1468+
bool isEmbeddedMaskingCompatibleHWIntrinsic() const;
14691469
#else
14701470
bool isCommutativeHWIntrinsic() const
14711471
{
@@ -1487,7 +1487,7 @@ struct GenTree
14871487
return false;
14881488
}
14891489

1490-
bool isEvexEmbeddedMaskingCompatibleHWIntrinsic() const
1490+
bool isEmbeddedMaskingCompatibleHWIntrinsic() const
14911491
{
14921492
return false;
14931493
}
@@ -2226,7 +2226,7 @@ struct GenTree
22262226
gtFlags &= ~GTF_ICON_HDL_MASK;
22272227
}
22282228

2229-
#if defined(TARGET_XARCH) && defined(FEATURE_HW_INTRINSICS)
2229+
#ifdef FEATURE_HW_INTRINSICS
22302230

22312231
bool IsEmbMaskOp()
22322232
{
@@ -2240,7 +2240,7 @@ struct GenTree
22402240
gtFlags |= GTF_HW_EM_OP;
22412241
}
22422242

2243-
#endif // TARGET_XARCH && FEATURE_HW_INTRINSICS
2243+
#endif // FEATURE_HW_INTRINSICS
22442244

22452245
static bool HandleKindDataIsInvariant(GenTreeFlags flags);
22462246

src/coreclr/jit/hwintrinsic.cpp

+43-17
Original file line numberDiff line numberDiff line change
@@ -1396,6 +1396,36 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
13961396
GenTree* op3 = nullptr;
13971397
GenTree* op4 = nullptr;
13981398

1399+
switch (numArgs)
1400+
{
1401+
case 4:
1402+
op4 = getArgForHWIntrinsic(sigReader.GetOp4Type(), sigReader.op4ClsHnd);
1403+
op4 = addRangeCheckIfNeeded(intrinsic, op4, mustExpand, immLowerBound, immUpperBound);
1404+
op3 = getArgForHWIntrinsic(sigReader.GetOp3Type(), sigReader.op3ClsHnd);
1405+
op2 = getArgForHWIntrinsic(sigReader.GetOp2Type(), sigReader.op2ClsHnd);
1406+
op1 = getArgForHWIntrinsic(sigReader.GetOp1Type(), sigReader.op1ClsHnd);
1407+
break;
1408+
1409+
case 3:
1410+
op3 = getArgForHWIntrinsic(sigReader.GetOp3Type(), sigReader.op3ClsHnd);
1411+
op2 = getArgForHWIntrinsic(sigReader.GetOp2Type(), sigReader.op2ClsHnd);
1412+
op1 = getArgForHWIntrinsic(sigReader.GetOp1Type(), sigReader.op1ClsHnd);
1413+
break;
1414+
1415+
case 2:
1416+
op2 = getArgForHWIntrinsic(sigReader.GetOp2Type(), sigReader.op2ClsHnd);
1417+
op2 = addRangeCheckIfNeeded(intrinsic, op2, mustExpand, immLowerBound, immUpperBound);
1418+
op1 = getArgForHWIntrinsic(sigReader.GetOp1Type(), sigReader.op1ClsHnd);
1419+
break;
1420+
1421+
case 1:
1422+
op1 = getArgForHWIntrinsic(sigReader.GetOp1Type(), sigReader.op1ClsHnd);
1423+
break;
1424+
1425+
default:
1426+
break;
1427+
}
1428+
13991429
switch (numArgs)
14001430
{
14011431
case 0:
@@ -1407,8 +1437,6 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
14071437

14081438
case 1:
14091439
{
1410-
op1 = getArgForHWIntrinsic(sigReader.GetOp1Type(), sigReader.op1ClsHnd);
1411-
14121440
if ((category == HW_Category_MemoryLoad) && op1->OperIs(GT_CAST))
14131441
{
14141442
// Although the API specifies a pointer, if what we have is a BYREF, that's what
@@ -1467,10 +1495,6 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
14671495

14681496
case 2:
14691497
{
1470-
op2 = getArgForHWIntrinsic(sigReader.GetOp2Type(), sigReader.op2ClsHnd);
1471-
op2 = addRangeCheckIfNeeded(intrinsic, op2, mustExpand, immLowerBound, immUpperBound);
1472-
op1 = getArgForHWIntrinsic(sigReader.GetOp1Type(), sigReader.op1ClsHnd);
1473-
14741498
retNode = isScalar
14751499
? gtNewScalarHWIntrinsicNode(nodeRetType, op1, op2, intrinsic)
14761500
: gtNewSimdHWIntrinsicNode(nodeRetType, op1, op2, intrinsic, simdBaseJitType, simdSize);
@@ -1524,10 +1548,6 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
15241548

15251549
case 3:
15261550
{
1527-
op3 = getArgForHWIntrinsic(sigReader.GetOp3Type(), sigReader.op3ClsHnd);
1528-
op2 = getArgForHWIntrinsic(sigReader.GetOp2Type(), sigReader.op2ClsHnd);
1529-
op1 = getArgForHWIntrinsic(sigReader.GetOp1Type(), sigReader.op1ClsHnd);
1530-
15311551
#ifdef TARGET_ARM64
15321552
if (intrinsic == NI_AdvSimd_LoadAndInsertScalar)
15331553
{
@@ -1569,12 +1589,6 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
15691589

15701590
case 4:
15711591
{
1572-
op4 = getArgForHWIntrinsic(sigReader.GetOp4Type(), sigReader.op4ClsHnd);
1573-
op4 = addRangeCheckIfNeeded(intrinsic, op4, mustExpand, immLowerBound, immUpperBound);
1574-
op3 = getArgForHWIntrinsic(sigReader.GetOp3Type(), sigReader.op3ClsHnd);
1575-
op2 = getArgForHWIntrinsic(sigReader.GetOp2Type(), sigReader.op2ClsHnd);
1576-
op1 = getArgForHWIntrinsic(sigReader.GetOp1Type(), sigReader.op1ClsHnd);
1577-
15781592
assert(!isScalar);
15791593
retNode =
15801594
gtNewSimdHWIntrinsicNode(nodeRetType, op1, op2, op3, op4, intrinsic, simdBaseJitType, simdSize);
@@ -1591,10 +1605,22 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
15911605
}
15921606

15931607
#if defined(TARGET_ARM64)
1594-
if (HWIntrinsicInfo::IsMaskedOperation(intrinsic))
1608+
if (HWIntrinsicInfo::IsExplicitMaskedOperation(intrinsic))
15951609
{
15961610
assert(numArgs > 0);
15971611
GenTree* op1 = retNode->AsHWIntrinsic()->Op(1);
1612+
if (intrinsic == NI_Sve_ConditionalSelect)
1613+
{
1614+
if (op1->IsVectorAllBitsSet())
1615+
{
1616+
return retNode->AsHWIntrinsic()->Op(2);
1617+
}
1618+
else if (op1->IsVectorZero())
1619+
{
1620+
return retNode->AsHWIntrinsic()->Op(3);
1621+
}
1622+
}
1623+
15981624
if (!varTypeIsMask(op1))
15991625
{
16001626
// Op1 input is a vector. HWIntrinsic requires a mask.

src/coreclr/jit/hwintrinsic.h

+27-2
Original file line numberDiff line numberDiff line change
@@ -186,11 +186,18 @@ enum HWIntrinsicFlag : unsigned int
186186
HW_Flag_ReturnsPerElementMask = 0x10000,
187187

188188
// The intrinsic uses a mask in arg1 to select elements present in the result
189-
HW_Flag_MaskedOperation = 0x20000,
189+
HW_Flag_ExplicitMaskedOperation = 0x20000,
190190

191191
// The intrinsic uses a mask in arg1 to select elements present in the result, and must use a low register.
192192
HW_Flag_LowMaskedOperation = 0x40000,
193193

194+
// The intrinsic can optionally use a mask in arg1 to select elements present in the result, which is not present in
195+
// the API call
196+
HW_Flag_OptionalEmbeddedMaskedOperation = 0x80000,
197+
198+
// The intrinsic uses a mask in arg1 to select elements present in the result, which is not present in the API call
199+
HW_Flag_EmbeddedMaskedOperation = 0x100000,
200+
194201
#else
195202
#error Unsupported platform
196203
#endif
@@ -872,7 +879,7 @@ struct HWIntrinsicInfo
872879
static bool IsMaskedOperation(NamedIntrinsic id)
873880
{
874881
const HWIntrinsicFlag flags = lookupFlags(id);
875-
return ((flags & HW_Flag_MaskedOperation) != 0) || IsLowMaskedOperation(id);
882+
return IsLowMaskedOperation(id) || IsOptionalEmbeddedMaskedOperation(id) || IsExplicitMaskedOperation(id);
876883
}
877884

878885
static bool IsLowMaskedOperation(NamedIntrinsic id)
@@ -881,6 +888,24 @@ struct HWIntrinsicInfo
881888
return (flags & HW_Flag_LowMaskedOperation) != 0;
882889
}
883890

891+
static bool IsOptionalEmbeddedMaskedOperation(NamedIntrinsic id)
892+
{
893+
const HWIntrinsicFlag flags = lookupFlags(id);
894+
return (flags & HW_Flag_OptionalEmbeddedMaskedOperation) != 0;
895+
}
896+
897+
static bool IsEmbeddedMaskedOperation(NamedIntrinsic id)
898+
{
899+
const HWIntrinsicFlag flags = lookupFlags(id);
900+
return (flags & HW_Flag_EmbeddedMaskedOperation) != 0;
901+
}
902+
903+
static bool IsExplicitMaskedOperation(NamedIntrinsic id)
904+
{
905+
const HWIntrinsicFlag flags = lookupFlags(id);
906+
return (flags & HW_Flag_ExplicitMaskedOperation) != 0;
907+
}
908+
884909
#endif // TARGET_ARM64
885910

886911
static bool HasSpecialSideEffect(NamedIntrinsic id)

src/coreclr/jit/hwintrinsicarm64.cpp

+17-3
Original file line numberDiff line numberDiff line change
@@ -2222,9 +2222,8 @@ GenTree* Compiler::gtNewSimdConvertVectorToMaskNode(var_types type,
22222222
assert(varTypeIsSIMD(node));
22232223

22242224
// ConvertVectorToMask uses cmpne which requires an embedded mask.
2225-
GenTree* embeddedMask = gtNewSimdHWIntrinsicNode(TYP_MASK, NI_Sve_CreateTrueMaskAll, simdBaseJitType, simdSize);
2226-
return gtNewSimdHWIntrinsicNode(TYP_MASK, embeddedMask, node, NI_Sve_ConvertVectorToMask, simdBaseJitType,
2227-
simdSize);
2225+
GenTree* trueMask = gtNewSimdAllTrueMaskNode(simdBaseJitType, simdSize);
2226+
return gtNewSimdHWIntrinsicNode(TYP_MASK, trueMask, node, NI_Sve_ConvertVectorToMask, simdBaseJitType, simdSize);
22282227
}
22292228

22302229
//------------------------------------------------------------------------
@@ -2246,4 +2245,19 @@ GenTree* Compiler::gtNewSimdConvertMaskToVectorNode(GenTreeHWIntrinsic* node, va
22462245
node->GetSimdSize());
22472246
}
22482247

2248+
//------------------------------------------------------------------------
2249+
// gtNewSimdAllTrueMaskNode: Create a mask node with all elements set to true
2250+
//
2251+
// Arguments:
2252+
// simdBaseJitType -- the base jit type of the nodes being masked
2253+
// simdSize -- the simd size of the nodes being masked
2254+
//
2255+
// Return Value:
2256+
// The mask
2257+
//
2258+
GenTree* Compiler::gtNewSimdAllTrueMaskNode(CorInfoType simdBaseJitType, unsigned simdSize)
2259+
{
2260+
return gtNewSimdHWIntrinsicNode(TYP_MASK, NI_Sve_CreateTrueMaskAll, simdBaseJitType, simdSize);
2261+
}
2262+
22492263
#endif // FEATURE_HW_INTRINSICS

0 commit comments

Comments
 (0)