Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JIT ARM64-SVE: Add TrueMask and LoadVector #98218

Merged
merged 18 commits into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/coreclr/jit/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -4554,6 +4554,11 @@ class Compiler
NamedIntrinsic intrinsic, GenTree* immOp, bool mustExpand, int immLowerBound, int immUpperBound);
GenTree* addRangeCheckForHWIntrinsic(GenTree* immOp, int immLowerBound, int immUpperBound);

#if defined(TARGET_ARM64)
a74nh marked this conversation as resolved.
Show resolved Hide resolved
GenTree* convertHWIntrinsicToMask(var_types type, GenTree* node, CorInfoType simdBaseJitType, unsigned simdSize);
GenTree* convertHWIntrinsicFromMask(GenTreeHWIntrinsic* node, var_types type);
#endif

#endif // FEATURE_HW_INTRINSICS
GenTree* impArrayAccessIntrinsic(CORINFO_CLASS_HANDLE clsHnd,
CORINFO_SIG_INFO* sig,
Expand Down
34 changes: 34 additions & 0 deletions src/coreclr/jit/emitarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7303,6 +7303,34 @@ emitter::code_t emitter::emitInsCodeSve(instruction ins, insFormat fmt)
}
}

// For the given 'elemsize' returns the 'arrangement' when used in a SVE vector register arrangement.
// Asserts and returns INS_OPTS_NONE if an invalid 'elemsize' is passed
//
/*static*/ insOpts emitter::optGetSveInsOpt(emitAttr elemsize)
{
switch (elemsize)
{
case EA_1BYTE:
return INS_OPTS_SCALABLE_B;

case EA_2BYTE:
return INS_OPTS_SCALABLE_H;

case EA_4BYTE:
return INS_OPTS_SCALABLE_S;

case EA_8BYTE:
return INS_OPTS_SCALABLE_D;

case EA_16BYTE:
return INS_OPTS_SCALABLE_Q;

default:
assert(!"Invalid emitAttr for sve vector register");
return INS_OPTS_NONE;
}
}

// For the given 'arrangement' returns the 'elemsize' specified by the SVE vector register arrangement
// asserts and returns EA_UNKNOWN if an invalid 'arrangement' value is passed
//
Expand Down Expand Up @@ -13020,6 +13048,12 @@ void emitter::emitIns_R_R_R(instruction ins,
fmt = IF_SVE_HP_3A;
break;

case INS_sve_ld1b:
case INS_sve_ld1h:
case INS_sve_ld1w:
case INS_sve_ld1d:
return emitIns_R_R_R_I(ins, size, reg1, reg2, reg3, 0, opt);

default:
unreached();
break;
Expand Down
3 changes: 3 additions & 0 deletions src/coreclr/jit/emitarm64.h
Original file line number Diff line number Diff line change
Expand Up @@ -1138,6 +1138,9 @@ static emitAttr optGetDatasize(insOpts arrangement);
// For the given 'arrangement' returns the 'elemsize' specified by the vector register arrangement
static emitAttr optGetElemsize(insOpts arrangement);

// For the given 'elemsize' returns the 'arrangement' when used in a SVE vector register arrangement.
static insOpts optGetSveInsOpt(emitAttr elemsize);

// For the given 'arrangement' returns the 'elemsize' specified by the SVE vector register arrangement
static emitAttr optGetSveElemsize(insOpts arrangement);

Expand Down
5 changes: 4 additions & 1 deletion src/coreclr/jit/gentree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26067,9 +26067,12 @@ bool GenTreeHWIntrinsic::OperIsMemoryLoad(GenTree** pAddr) const
case NI_AdvSimd_Arm64_LoadAndInsertScalarVector128x2:
case NI_AdvSimd_Arm64_LoadAndInsertScalarVector128x3:
case NI_AdvSimd_Arm64_LoadAndInsertScalarVector128x4:

addr = Op(3);
break;

case NI_Sve_LoadVector:
addr = Op(2);
break;
#endif // TARGET_ARM64

default:
Expand Down
49 changes: 39 additions & 10 deletions src/coreclr/jit/hwintrinsic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1356,6 +1356,15 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
compFloatingPointUsed = true;
}

var_types nodeRetType = retType;
#if defined(TARGET_ARM64)
if (HWIntrinsicInfo::ReturnsPerElementMask(intrinsic))
{
// Ensure the result is generated to a mask.
nodeRetType = TYP_MASK;
}
#endif // defined(TARGET_ARM64)

// table-driven importer of simple intrinsics
if (impIsTableDrivenHWIntrinsic(intrinsic, category))
{
Expand Down Expand Up @@ -1392,7 +1401,7 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
case 0:
{
assert(!isScalar);
retNode = gtNewSimdHWIntrinsicNode(retType, intrinsic, simdBaseJitType, simdSize);
retNode = gtNewSimdHWIntrinsicNode(nodeRetType, intrinsic, simdBaseJitType, simdSize);
break;
}

Expand All @@ -1410,8 +1419,8 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
}
}

retNode = isScalar ? gtNewScalarHWIntrinsicNode(retType, op1, intrinsic)
: gtNewSimdHWIntrinsicNode(retType, op1, intrinsic, simdBaseJitType, simdSize);
retNode = isScalar ? gtNewScalarHWIntrinsicNode(nodeRetType, op1, intrinsic)
: gtNewSimdHWIntrinsicNode(nodeRetType, op1, intrinsic, simdBaseJitType, simdSize);

#if defined(TARGET_XARCH)
switch (intrinsic)
Expand Down Expand Up @@ -1462,8 +1471,9 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
op2 = addRangeCheckIfNeeded(intrinsic, op2, mustExpand, immLowerBound, immUpperBound);
op1 = getArgForHWIntrinsic(sigReader.GetOp1Type(), sigReader.op1ClsHnd);

retNode = isScalar ? gtNewScalarHWIntrinsicNode(retType, op1, op2, intrinsic)
: gtNewSimdHWIntrinsicNode(retType, op1, op2, intrinsic, simdBaseJitType, simdSize);
retNode = isScalar
? gtNewScalarHWIntrinsicNode(nodeRetType, op1, op2, intrinsic)
: gtNewSimdHWIntrinsicNode(nodeRetType, op1, op2, intrinsic, simdBaseJitType, simdSize);

#ifdef TARGET_XARCH
if ((intrinsic == NI_SSE42_Crc32) || (intrinsic == NI_SSE42_X64_Crc32))
Expand Down Expand Up @@ -1543,9 +1553,9 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
op3 = addRangeCheckIfNeeded(intrinsic, op3, mustExpand, immLowerBound, immUpperBound);
}

retNode = isScalar
? gtNewScalarHWIntrinsicNode(retType, op1, op2, op3, intrinsic)
: gtNewSimdHWIntrinsicNode(retType, op1, op2, op3, intrinsic, simdBaseJitType, simdSize);
retNode = isScalar ? gtNewScalarHWIntrinsicNode(nodeRetType, op1, op2, op3, intrinsic)
: gtNewSimdHWIntrinsicNode(nodeRetType, op1, op2, op3, intrinsic, simdBaseJitType,
simdSize);

#ifdef TARGET_XARCH
if ((intrinsic == NI_AVX2_GatherVector128) || (intrinsic == NI_AVX2_GatherVector256))
Expand All @@ -1566,7 +1576,8 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
op1 = getArgForHWIntrinsic(sigReader.GetOp1Type(), sigReader.op1ClsHnd);

assert(!isScalar);
retNode = gtNewSimdHWIntrinsicNode(retType, op1, op2, op3, op4, intrinsic, simdBaseJitType, simdSize);
retNode =
gtNewSimdHWIntrinsicNode(nodeRetType, op1, op2, op3, op4, intrinsic, simdBaseJitType, simdSize);
break;
}

Expand All @@ -1576,8 +1587,26 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
}
else
{
retNode = impSpecialIntrinsic(intrinsic, clsHnd, method, sig, simdBaseJitType, retType, simdSize);
retNode = impSpecialIntrinsic(intrinsic, clsHnd, method, sig, simdBaseJitType, nodeRetType, simdSize);
}

#if defined(TARGET_ARM64)
if (HWIntrinsicInfo::IsMaskedOperation(intrinsic))
{
// Op1 input is a vector. HWInstrinsic requires a mask, so convert to a mask.
assert(numArgs > 0);
GenTree* op1 = retNode->AsHWIntrinsic()->Op(1);
op1 = convertHWIntrinsicToMask(retType, op1, simdBaseJitType, simdSize);
retNode->AsHWIntrinsic()->Op(1) = op1;
}

if (retType != nodeRetType)
{
// HWInstrinsic returns a mask, but all returns must be vectors, so convert mask to vector.
assert(HWIntrinsicInfo::ReturnsPerElementMask(intrinsic));
retNode = convertHWIntrinsicFromMask(retNode->AsHWIntrinsic(), retType);
}
#endif // defined(TARGET_ARM64)

if ((retNode != nullptr) && retNode->OperIs(GT_HWINTRINSIC))
{
Expand Down
41 changes: 37 additions & 4 deletions src/coreclr/jit/hwintrinsic.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ enum HWIntrinsicCategory : uint8_t
HW_Category_ShiftLeftByImmediate,
HW_Category_ShiftRightByImmediate,
HW_Category_SIMDByIndexedElement,
HW_Category_EnumPattern,

// Helper intrinsics
// - do not directly correspond to a instruction, such as Vector64.AllBitsSet
Expand Down Expand Up @@ -175,6 +176,21 @@ enum HWIntrinsicFlag : unsigned int

// The intrinsic needs consecutive registers
HW_Flag_NeedsConsecutiveRegisters = 0x4000,

// The intrinsic uses scalable registers
HW_Flag_Scalable = 0x8000,

// Returns Per-Element Mask
// the intrinsic returns a vector containing elements that are either "all bits set" or "all bits clear"
// this output can be used as a per-element mask
HW_Flag_ReturnsPerElementMask = 0x10000,
kunalspathak marked this conversation as resolved.
Show resolved Hide resolved

// The intrinsic uses a mask in arg1 to select elements present in the result
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

arg1: Is it always be the case?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we not just check for TYP_MASK to determine this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

arg1: Is it always be the case?

Yes, that's the sve convention. Result, then mask, then inputs.

Can we not just check for TYP_MASK to determine this?

Ok, that sounds better. I can look and see how this would be done.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we not just check for TYP_MASK to determine this?

@tannergooding - Looking closer at this, I'm not quite sure what this would entail.

In hwintrinsiclistxarch.h the only reference to mask is use of HW_Flag_ReturnsPerElementMask.

I can't see any obvious way for the jit to understand know that the first arg of the method is expected to be a predicate mask, other than to use the enum or hardcode it with case statements somewhere.

The jit can check the type of the actual arg1 child node, but that only tells us what the type actually is, and not what the expected type is. I imagine I'll have to write code that says if the actual type and expected type don't match, then somehow convert arg1 to the expected type.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I imagine I'll have to write code that says if the actual type and expected type don't match, then somehow convert arg1 to the expected type.

Yes, basically.

Most intrinsics support masking optionally and so you'll have something similar to this https://github.com/dotnet/runtime/blob/main/src/coreclr/jit/gentree.cpp#L19988-L20008. That is, you'll have some bool GenTree::isSveEmbeddedMaskingCompatibleHWIntrinsic() which likely looks up a flag in the hwintrinsiclistarm64.h table to see if that particular intrinsic supports embedded masking/predication.

There are then a handful of intrinsics which require masking. For example, SVE comparison intrinsics may always return a TYP_MASK, in which case you could either add a new entry to the table such as HW_Flag_ReturnsSveMask or explicitly handle it like xarch does here: https://github.com/dotnet/runtime/blob/main/src/coreclr/jit/hwintrinsicxarch.cpp#L3985-L3999

There are then a handful of intrinsics which require mask inputs and which aren't recognized via pattern matching. You would likewise add a flag or manually handle the few of them like this: https://github.com/dotnet/runtime/blob/main/src/coreclr/jit/hwintrinsicxarch.cpp#L3970-L3983

The insertion of the ConvertVectorToMask and ConvertMaskToVector intrinsics is important since the user may have passedin something that was of the incorrect type. For example, it might've been a mask of bytes, where we needed a mask of ints; or might've been an actual vector where we needed a mask and vice-versa. Likewise it ensures we don't need to check the type on every other intrinsic that does properly take a vector.

We then make this efficient in morph (see https://github.com/dotnet/runtime/blob/main/src/coreclr/jit/morph.cpp#L10775-L10827) where we ensure that we aren't unnecessarily converting from mask to vector and back to mask, or vice versa. This allows things that take a mask to consume a produced mask directly and gives the optimal codegen expected in most scenarios.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That was the comment around

We are notably missing and need to add a bit which handles the case where we have LCL_VAR TYP_SIMD = TYP_MASK because that can currently block the ability to consume a mask directly if it's multi-use. We ideally would have it stored as LCL_VAR TYP_MASK instead (even if the use manually hoisted as a Vector in C#/IL) and then have the things consume it as ConvertMaskToVector(LCL_VAR) if they actually needed a vector.

This shouldn't be overly complex to add, however, it's just not been done as of yet.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right. That feels like it might touch quite a few files. Given the size of this PR, do you think it's worth keeping this PR as is, and then putting the LCL_VAR TYP_MASK in a follow on, along with the lowering code?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

then putting the LCL_VAR TYP_MASK in a follow on

Yes, I think this would even be the preferred route given its not required and is its own isolated change really.

along with the lowering code?

Which lowering code is this?


In general I think its fine for this PR to be the basic plumbing of TYP_MASK support into the Arm64 side of the JIT. As long as TrueMask and LoadVector are minimally working as expected, I think we're golden and we can extend that to other operations and enable optimizations separately. That is exactly what we did for xarch to help with review and scoping.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which lowering code is this?

I added some code do the remove the mask->vector->mask and vector->mask->vector conversions. But, nothing in this PR uses it because of the lcl var, so I decided not to push it.

Will mark this as ready now.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will mark this as ready now.

... but not quite yet, as I need #99049 to merge so I can remove it from this PR.

HW_Flag_MaskedOperation = 0x20000,

// The intrinsic uses a mask in arg1 to select elements present in the result, and must use a low register.
HW_Flag_LowMaskedOperation = 0x40000,

#else
#error Unsupported platform
#endif
Expand Down Expand Up @@ -654,10 +670,8 @@ struct HWIntrinsicInfo
static bool ReturnsPerElementMask(NamedIntrinsic id)
{
HWIntrinsicFlag flags = lookupFlags(id);
#if defined(TARGET_XARCH)
#if defined(TARGET_XARCH) || defined(TARGET_ARM64)
return (flags & HW_Flag_ReturnsPerElementMask) != 0;
#elif defined(TARGET_ARM64)
unreached();
#else
#error Unsupported platform
#endif
Expand Down Expand Up @@ -848,6 +862,25 @@ struct HWIntrinsicInfo
const HWIntrinsicFlag flags = lookupFlags(id);
return (flags & HW_Flag_HasImmediateOperand) != 0;
}

static bool IsScalable(NamedIntrinsic id)
{
const HWIntrinsicFlag flags = lookupFlags(id);
return (flags & HW_Flag_Scalable) != 0;
}

static bool IsMaskedOperation(NamedIntrinsic id)
{
const HWIntrinsicFlag flags = lookupFlags(id);
return ((flags & HW_Flag_MaskedOperation) != 0) || IsLowMaskedOperation(id);
}

static bool IsLowMaskedOperation(NamedIntrinsic id)
{
const HWIntrinsicFlag flags = lookupFlags(id);
return (flags & HW_Flag_LowMaskedOperation) != 0;
}

#endif // TARGET_ARM64

static bool HasSpecialSideEffect(NamedIntrinsic id)
Expand Down Expand Up @@ -907,7 +940,7 @@ struct HWIntrinsic final
InitializeBaseType(node);
}

bool IsTableDriven() const
bool codeGenIsTableDriven() const
{
// TODO-Arm64-Cleanup - make more categories to the table-driven framework
bool isTableDrivenCategory = category != HW_Category_Helper;
Expand Down
54 changes: 54 additions & 0 deletions src/coreclr/jit/hwintrinsicarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,20 @@ void HWIntrinsicInfo::lookupImmBounds(
immUpperBound = Compiler::getSIMDVectorLength(simdSize, baseType) - 1;
break;

case NI_Sve_CreateTrueMaskByte:
case NI_Sve_CreateTrueMaskDouble:
case NI_Sve_CreateTrueMaskInt16:
case NI_Sve_CreateTrueMaskInt32:
case NI_Sve_CreateTrueMaskInt64:
case NI_Sve_CreateTrueMaskSByte:
case NI_Sve_CreateTrueMaskSingle:
case NI_Sve_CreateTrueMaskUInt16:
case NI_Sve_CreateTrueMaskUInt32:
case NI_Sve_CreateTrueMaskUInt64:
immLowerBound = (int)SVE_PATTERN_POW2;
immUpperBound = (int)SVE_PATTERN_ALL;
break;

default:
unreached();
}
Expand Down Expand Up @@ -2179,6 +2193,7 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
retNode = gtNewSimdHWIntrinsicNode(retType, op1, op2, op3, intrinsic, simdBaseJitType, simdSize);
break;
}

default:
{
return nullptr;
Expand All @@ -2188,4 +2203,43 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
return retNode;
}

//------------------------------------------------------------------------
// convertHWIntrinsicFromMask: Convert a HW instrinsic vector node to a mask
//
// Arguments:
// node -- The node to convert
// simdBaseJitType -- the base jit type of the converted node
// simdSize -- the simd size of the converted node
//
// Return Value:
// The node converted to the a mask type
//
GenTree* Compiler::convertHWIntrinsicToMask(var_types type,
GenTree* node,
CorInfoType simdBaseJitType,
unsigned simdSize)
{
// ConvertVectorToMask uses cmpne which requires an embedded mask.
GenTree* embeddedMask = gtNewSimdHWIntrinsicNode(TYP_MASK, NI_Sve_CreateTrueMaskAll, simdBaseJitType, simdSize);
return gtNewSimdHWIntrinsicNode(TYP_MASK, embeddedMask, node, NI_Sve_ConvertVectorToMask, simdBaseJitType,
simdSize);
}

//------------------------------------------------------------------------
// convertHWIntrinsicFromMask: Convert a HW instrinsic mask node to a vector
//
// Arguments:
// node -- The node to convert
// type -- The type of the node to convert to
//
// Return Value:
// The node converted to the given type
//
GenTree* Compiler::convertHWIntrinsicFromMask(GenTreeHWIntrinsic* node, var_types type)
{
assert(node->TypeGet() == TYP_MASK);
return gtNewSimdHWIntrinsicNode(type, node, NI_Sve_ConvertMaskToVector, node->GetSimdBaseJitType(),
node->GetSimdSize());
}

#endif // FEATURE_HW_INTRINSICS
Loading
Loading