Skip to content

Commit

Permalink
ARM64-SVE: GatherPrefetch (#103826)
Browse files Browse the repository at this point in the history
* ARM64-SVE: GatherPrefetch

* Ensure invalid immediates are invalid

* fixed resolved conflicts

---------

Co-authored-by: Kunal Pathak <Kunal.Pathak@microsoft.com>
  • Loading branch information
a74nh and kunalspathak authored Jul 2, 2024
1 parent c65d921 commit 7b9f1c3
Show file tree
Hide file tree
Showing 13 changed files with 1,548 additions and 22 deletions.
4 changes: 4 additions & 0 deletions src/coreclr/jit/gentree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27449,6 +27449,10 @@ bool GenTreeHWIntrinsic::OperIsMemoryLoad(GenTree** pAddr) const
addr = Op(2);
break;

case NI_Sve_GatherPrefetch8Bit:
case NI_Sve_GatherPrefetch16Bit:
case NI_Sve_GatherPrefetch32Bit:
case NI_Sve_GatherPrefetch64Bit:
case NI_Sve_GatherVector:
case NI_Sve_GatherVectorByteZeroExtend:
case NI_Sve_GatherVectorInt16SignExtend:
Expand Down
24 changes: 24 additions & 0 deletions src/coreclr/jit/hwintrinsic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1856,6 +1856,14 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
break;

#elif defined(TARGET_ARM64)
case NI_Sve_GatherPrefetch8Bit:
case NI_Sve_GatherPrefetch16Bit:
case NI_Sve_GatherPrefetch32Bit:
case NI_Sve_GatherPrefetch64Bit:
assert(varTypeIsSIMD(op2->TypeGet()));
retNode->AsHWIntrinsic()->SetAuxiliaryJitType(getBaseJitTypeOfSIMDType(sigReader.op2ClsHnd));
break;

case NI_Sve_GatherVector:
case NI_Sve_GatherVectorByteZeroExtend:
case NI_Sve_GatherVectorInt16SignExtend:
Expand Down Expand Up @@ -1885,6 +1893,22 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
assert(!isScalar);
retNode =
gtNewSimdHWIntrinsicNode(nodeRetType, op1, op2, op3, op4, intrinsic, simdBaseJitType, simdSize);

#if defined(TARGET_ARM64)
switch (intrinsic)
{
case NI_Sve_GatherPrefetch8Bit:
case NI_Sve_GatherPrefetch16Bit:
case NI_Sve_GatherPrefetch32Bit:
case NI_Sve_GatherPrefetch64Bit:
assert(varTypeIsSIMD(op3->TypeGet()));
retNode->AsHWIntrinsic()->SetAuxiliaryJitType(getBaseJitTypeOfSIMDType(sigReader.op3ClsHnd));
break;

default:
break;
}
#endif
break;
}

Expand Down
4 changes: 4 additions & 0 deletions src/coreclr/jit/hwintrinsicarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,10 @@ void HWIntrinsicInfo::lookupImmBounds(
}
break;

case NI_Sve_GatherPrefetch8Bit:
case NI_Sve_GatherPrefetch16Bit:
case NI_Sve_GatherPrefetch32Bit:
case NI_Sve_GatherPrefetch64Bit:
case NI_Sve_PrefetchBytes:
case NI_Sve_PrefetchInt16:
case NI_Sve_PrefetchInt32:
Expand Down
89 changes: 67 additions & 22 deletions src/coreclr/jit/hwintrinsiccodegenarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1615,22 +1615,11 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
{
assert(hasImmediateOperand);
assert(HWIntrinsicInfo::HasEnumOperand(intrin.id));
if (intrin.op3->IsCnsIntOrI())
{
GetEmitter()->emitIns_PRFOP_R_R_I(ins, emitSize,
(insSvePrfop)intrin.op3->AsIntConCommon()->IconValue(), op1Reg,
op2Reg, 0);
}
else
HWIntrinsicImmOpHelper helper(this, intrin.op3, node);
for (helper.EmitBegin(); !helper.Done(); helper.EmitCaseEnd())
{
assert(!intrin.op3->isContainedIntOrIImmed());

HWIntrinsicImmOpHelper helper(this, intrin.op3, node);
for (helper.EmitBegin(); !helper.Done(); helper.EmitCaseEnd())
{
const insSvePrfop prfop = (insSvePrfop)helper.ImmValue();
GetEmitter()->emitIns_PRFOP_R_R_I(ins, emitSize, prfop, op1Reg, op2Reg, 0);
}
const insSvePrfop prfop = (insSvePrfop)helper.ImmValue();
GetEmitter()->emitIns_PRFOP_R_R_I(ins, emitSize, prfop, op1Reg, op2Reg, 0);
}
break;
}
Expand Down Expand Up @@ -1992,6 +1981,61 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
break;
}

case NI_Sve_GatherPrefetch8Bit:
case NI_Sve_GatherPrefetch16Bit:
case NI_Sve_GatherPrefetch32Bit:
case NI_Sve_GatherPrefetch64Bit:
{
assert(hasImmediateOperand);

if (!varTypeIsSIMD(intrin.op2->gtType))
{
// GatherPrefetch...(Vector<T> mask, T* address, Vector<T2> indices, SvePrefetchType prefetchType)

assert(intrin.numOperands == 4);
emitAttr baseSize = emitActualTypeSize(intrin.baseType);
insScalableOpts sopt = INS_SCALABLE_OPTS_NONE;

if (baseSize == EA_8BYTE)
{
// Index is multiplied.
sopt = (ins == INS_sve_prfb) ? INS_SCALABLE_OPTS_NONE : INS_SCALABLE_OPTS_LSL_N;
}
else
{
// Index is sign or zero extended to 64bits, then multiplied.
assert(baseSize == EA_4BYTE);
opt = varTypeIsUnsigned(node->GetAuxiliaryType()) ? INS_OPTS_SCALABLE_S_UXTW
: INS_OPTS_SCALABLE_S_SXTW;

sopt = (ins == INS_sve_prfb) ? INS_SCALABLE_OPTS_NONE : INS_SCALABLE_OPTS_MOD_N;
}

HWIntrinsicImmOpHelper helper(this, intrin.op4, node);
for (helper.EmitBegin(); !helper.Done(); helper.EmitCaseEnd())
{
const insSvePrfop prfop = (insSvePrfop)helper.ImmValue();
GetEmitter()->emitIns_PRFOP_R_R_R(ins, emitSize, prfop, op1Reg, op2Reg, op3Reg, opt, sopt);
}
}
else
{
// GatherPrefetch...(Vector<T> mask, Vector<T2> addresses, SvePrefetchType prefetchType)

opt = emitter::optGetSveInsOpt(emitTypeSize(node->GetAuxiliaryType()));

assert(intrin.numOperands == 3);
HWIntrinsicImmOpHelper helper(this, intrin.op3, node);
for (helper.EmitBegin(); !helper.Done(); helper.EmitCaseEnd())
{
const insSvePrfop prfop = (insSvePrfop)helper.ImmValue();
GetEmitter()->emitIns_PRFOP_R_R_I(ins, emitSize, prfop, op1Reg, op2Reg, 0, opt);
}
}

break;
}

case NI_Sve_GatherVector:
case NI_Sve_GatherVectorByteZeroExtend:
case NI_Sve_GatherVectorInt16SignExtend:
Expand All @@ -2009,14 +2053,14 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
// GatherVector...(Vector<T> mask, T* address, Vector<T2> indices)

assert(intrin.numOperands == 3);
emitAttr baseSize = emitActualTypeSize(intrin.baseType);
emitAttr baseSize = emitActualTypeSize(intrin.baseType);
insScalableOpts sopt = INS_SCALABLE_OPTS_NONE;

if (baseSize == EA_8BYTE)
{
// Index is multiplied.
insScalableOpts sopt = (ins == INS_sve_ld1b || ins == INS_sve_ld1sb) ? INS_SCALABLE_OPTS_NONE
: INS_SCALABLE_OPTS_LSL_N;
GetEmitter()->emitIns_R_R_R_R(ins, emitSize, targetReg, op1Reg, op2Reg, op3Reg, opt, sopt);
sopt = (ins == INS_sve_ld1b || ins == INS_sve_ld1sb) ? INS_SCALABLE_OPTS_NONE
: INS_SCALABLE_OPTS_LSL_N;
}
else
{
Expand All @@ -2025,10 +2069,11 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
opt = varTypeIsUnsigned(node->GetAuxiliaryType()) ? INS_OPTS_SCALABLE_S_UXTW
: INS_OPTS_SCALABLE_S_SXTW;

insScalableOpts sopt = (ins == INS_sve_ld1b || ins == INS_sve_ld1sb) ? INS_SCALABLE_OPTS_NONE
: INS_SCALABLE_OPTS_MOD_N;
GetEmitter()->emitIns_R_R_R_R(ins, emitSize, targetReg, op1Reg, op2Reg, op3Reg, opt, sopt);
sopt = (ins == INS_sve_ld1b || ins == INS_sve_ld1sb) ? INS_SCALABLE_OPTS_NONE
: INS_SCALABLE_OPTS_MOD_N;
}

GetEmitter()->emitIns_R_R_R_R(ins, emitSize, targetReg, op1Reg, op2Reg, op3Reg, opt, sopt);
}
else
{
Expand Down
4 changes: 4 additions & 0 deletions src/coreclr/jit/hwintrinsiclistarm64sve.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ HARDWARE_INTRINSIC(Sve, FusedMultiplyAddNegated,
HARDWARE_INTRINSIC(Sve, FusedMultiplySubtract, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fmls, INS_sve_fmls}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation|HW_Flag_FmaIntrinsic|HW_Flag_SpecialCodeGen)
HARDWARE_INTRINSIC(Sve, FusedMultiplySubtractBySelectedScalar, -1, 4, true, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fmls, INS_sve_fmls}, HW_Category_SIMDByIndexedElement, HW_Flag_Scalable|HW_Flag_HasImmediateOperand|HW_Flag_HasRMWSemantics|HW_Flag_FmaIntrinsic|HW_Flag_LowVectorOperation)
HARDWARE_INTRINSIC(Sve, FusedMultiplySubtractNegated, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fnmls, INS_sve_fnmls}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation|HW_Flag_FmaIntrinsic|HW_Flag_SpecialCodeGen)
HARDWARE_INTRINSIC(Sve, GatherPrefetch16Bit, -1, -1, false, {INS_invalid, INS_invalid, INS_sve_prfh, INS_sve_prfh, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_MemoryLoad, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_SpecialCodeGen|HW_Flag_ExplicitMaskedOperation|HW_Flag_LowMaskedOperation|HW_Flag_HasImmediateOperand|HW_Flag_HasEnumOperand)
HARDWARE_INTRINSIC(Sve, GatherPrefetch32Bit, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_prfw, INS_sve_prfw, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_MemoryLoad, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_SpecialCodeGen|HW_Flag_ExplicitMaskedOperation|HW_Flag_LowMaskedOperation|HW_Flag_HasImmediateOperand|HW_Flag_HasEnumOperand)
HARDWARE_INTRINSIC(Sve, GatherPrefetch64Bit, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_prfd, INS_sve_prfd, INS_invalid, INS_invalid}, HW_Category_MemoryLoad, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_SpecialCodeGen|HW_Flag_ExplicitMaskedOperation|HW_Flag_LowMaskedOperation|HW_Flag_HasImmediateOperand|HW_Flag_HasEnumOperand)
HARDWARE_INTRINSIC(Sve, GatherPrefetch8Bit, -1, -1, false, {INS_sve_prfb, INS_sve_prfb, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_MemoryLoad, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_SpecialCodeGen|HW_Flag_ExplicitMaskedOperation|HW_Flag_LowMaskedOperation|HW_Flag_HasImmediateOperand|HW_Flag_HasEnumOperand)
HARDWARE_INTRINSIC(Sve, GatherVector, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_ld1w, INS_sve_ld1w, INS_sve_ld1d, INS_sve_ld1d, INS_sve_ld1w, INS_sve_ld1d}, HW_Category_MemoryLoad, HW_Flag_Scalable|HW_Flag_SpecialCodeGen|HW_Flag_ExplicitMaskedOperation|HW_Flag_LowMaskedOperation)
HARDWARE_INTRINSIC(Sve, GatherVectorByteZeroExtend, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_ld1b, INS_sve_ld1b, INS_sve_ld1b, INS_sve_ld1b, INS_invalid, INS_invalid}, HW_Category_MemoryLoad, HW_Flag_Scalable|HW_Flag_SpecialCodeGen|HW_Flag_ExplicitMaskedOperation|HW_Flag_LowMaskedOperation)
HARDWARE_INTRINSIC(Sve, GatherVectorInt16SignExtend, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_ld1sh, INS_sve_ld1sh, INS_sve_ld1sh, INS_sve_ld1sh, INS_invalid, INS_invalid}, HW_Category_MemoryLoad, HW_Flag_Scalable|HW_Flag_SpecialCodeGen|HW_Flag_ExplicitMaskedOperation|HW_Flag_LowMaskedOperation)
Expand Down
23 changes: 23 additions & 0 deletions src/coreclr/jit/lowerarmarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3427,6 +3427,29 @@ void Lowering::ContainCheckHWIntrinsic(GenTreeHWIntrinsic* node)
}
break;

case NI_Sve_GatherPrefetch8Bit:
case NI_Sve_GatherPrefetch16Bit:
case NI_Sve_GatherPrefetch32Bit:
case NI_Sve_GatherPrefetch64Bit:
assert(hasImmediateOperand);
if (!varTypeIsSIMD(intrin.op2->gtType))
{
assert(varTypeIsIntegral(intrin.op4));
if (intrin.op4->IsCnsIntOrI())
{
MakeSrcContained(node, intrin.op4);
}
}
else
{
assert(varTypeIsIntegral(intrin.op3));
if (intrin.op3->IsCnsIntOrI())
{
MakeSrcContained(node, intrin.op3);
}
}
break;

case NI_Sve_SaturatingDecrementBy16BitElementCount:
case NI_Sve_SaturatingDecrementBy32BitElementCount:
case NI_Sve_SaturatingDecrementBy64BitElementCount:
Expand Down
37 changes: 37 additions & 0 deletions src/coreclr/jit/lsraarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1475,6 +1475,20 @@ int LinearScan::BuildHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree, int* pDstCou
needBranchTargetReg = !intrin.op1->isContainedIntOrIImmed();
break;

case NI_Sve_GatherPrefetch8Bit:
case NI_Sve_GatherPrefetch16Bit:
case NI_Sve_GatherPrefetch32Bit:
case NI_Sve_GatherPrefetch64Bit:
if (!varTypeIsSIMD(intrin.op2->gtType))
{
needBranchTargetReg = !intrin.op4->isContainedIntOrIImmed();
}
else
{
needBranchTargetReg = !intrin.op3->isContainedIntOrIImmed();
}
break;

case NI_Sve_SaturatingDecrementBy16BitElementCount:
case NI_Sve_SaturatingDecrementBy32BitElementCount:
case NI_Sve_SaturatingDecrementBy64BitElementCount:
Expand Down Expand Up @@ -2005,6 +2019,29 @@ int LinearScan::BuildHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree, int* pDstCou
srcCount += BuildAddrUses(intrin.op2);
break;

case NI_Sve_GatherPrefetch8Bit:
case NI_Sve_GatherPrefetch16Bit:
case NI_Sve_GatherPrefetch32Bit:
case NI_Sve_GatherPrefetch64Bit:
case NI_Sve_GatherVector:
case NI_Sve_GatherVectorByteZeroExtend:
case NI_Sve_GatherVectorInt16SignExtend:
case NI_Sve_GatherVectorInt16WithByteOffsetsSignExtend:
case NI_Sve_GatherVectorInt32SignExtend:
case NI_Sve_GatherVectorInt32WithByteOffsetsSignExtend:
case NI_Sve_GatherVectorSByteSignExtend:
case NI_Sve_GatherVectorUInt16WithByteOffsetsZeroExtend:
case NI_Sve_GatherVectorUInt16ZeroExtend:
case NI_Sve_GatherVectorUInt32WithByteOffsetsZeroExtend:
case NI_Sve_GatherVectorUInt32ZeroExtend:
assert(intrinsicTree->OperIsMemoryLoadOrStore());
if (!varTypeIsSIMD(intrin.op2->gtType))
{
srcCount += BuildAddrUses(intrin.op2);
break;
}
FALLTHROUGH;

default:
{
SingleTypeRegSet candidates = lowVectorOperandNum == 2 ? lowVectorCandidates : RBM_NONE;
Expand Down
Loading

0 comments on commit 7b9f1c3

Please sign in to comment.