Implement StoreVector64x2 and StoreVector128x2 for Arm64 (#92109)
* Implement StoreVector128x2 for Arm64

* Remove redundant implementations

* Implement StoreVector64x2 for Arm64

* Remove StoreVector64x2 implementation for Arm64

This reverts commit 49ef72e.

* Fix instruction type for the StoreVector128x2 intrinsic

* Address review comments:

* Arrange APIs alphabetically

* Add StoreVector64x2

* Fix the invalid instructions

* Add test cases

* Update src/coreclr/jit/hwintrinsicarm64.cpp

Co-authored-by: Bruce Forstall <brucefo@microsoft.com>

---------

Co-authored-by: Kunal Pathak <Kunal.Pathak@microsoft.com>
Co-authored-by: Bruce Forstall <brucefo@microsoft.com>
3 people authored Sep 29, 2023
1 parent f401263 commit 0cc9f21
Showing 9 changed files with 621 additions and 17 deletions.
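For context on the new APIs: ST2 is an interleaving store. It writes element 0 of the first vector, then element 0 of the second, then element 1 of the first, and so on, rather than storing the two vectors back to back. A minimal usage sketch follows; the class, method, and values are illustrative only, not part of this commit.

    using System.Runtime.Intrinsics;
    using System.Runtime.Intrinsics.Arm;

    internal static class St2Sketch
    {
        public static unsafe void InterleaveBytes()
        {
            if (!AdvSimd.Arm64.IsSupported)
            {
                return; // a real caller would fall back to a scalar loop
            }

            byte* dest = stackalloc byte[32];
            Vector128<byte> first  = Vector128.Create((byte)0x11);
            Vector128<byte> second = Vector128.Create((byte)0x22);

            // Emits: ST2 { Vn.16B, Vn+1.16B }, [Xn]
            // dest now holds 0x11, 0x22, 0x11, 0x22, ... (32 bytes).
            AdvSimd.Arm64.StoreVector128x2(dest, (first, second));
        }
    }

The interleaving is also why the two source vectors must occupy consecutive SIMD registers, which is what most of the JIT changes below arrange.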
42 changes: 42 additions & 0 deletions src/coreclr/jit/hwintrinsicarm64.cpp
@@ -1728,6 +1728,48 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
break;
}

case NI_AdvSimd_StoreVector64x2:
case NI_AdvSimd_Arm64_StoreVector128x2:
{
assert(sig->numArgs == 2);
assert(retType == TYP_VOID);

CORINFO_ARG_LIST_HANDLE arg1 = sig->args;
CORINFO_ARG_LIST_HANDLE arg2 = info.compCompHnd->getArgNext(arg1);
var_types argType = TYP_UNKNOWN;
CORINFO_CLASS_HANDLE argClass = NO_CLASS_HANDLE;

// Inspect the tuple argument (arg2) first: its class handle tells us how many vector fields it carries.
argType = JITtype2varType(strip(info.compCompHnd->getArgType(sig, arg2, &argClass)));
op2 = impPopStack().val;
unsigned fieldCount = info.compCompHnd->getClassNumInstanceFields(argClass);
// Then fetch the address argument (arg1).
argType = JITtype2varType(strip(info.compCompHnd->getArgType(sig, arg1, &argClass)));
op1 = getArgForHWIntrinsic(argType, argClass);

assert(op2->TypeGet() == TYP_STRUCT);
if (op1->OperIs(GT_CAST))
{
// Although the API specifies a pointer, if what we have is a BYREF, that's what
// we really want, so throw away the cast.
if (op1->gtGetOp1()->TypeGet() == TYP_BYREF)
{
op1 = op1->gtGetOp1();
}
}

// If the tuple is not already a local, spill it to one so its fields can be read back individually.
if (!op2->OperIs(GT_LCL_VAR))
{
unsigned tmp = lvaGrabTemp(true DEBUGARG("StoreVectorNx2 temp tree"));

impStoreTemp(tmp, op2, CHECK_SPILL_NONE);
op2 = gtNewLclvNode(tmp, argType);
}
// Expand the tuple into a GT_FIELD_LIST with one entry per vector; LSRA assigns these entries consecutive registers.
op2 = gtConvertTableOpToFieldList(op2, fieldCount);

// Flag the method so the register allocator knows it must build consecutive-register sets.
info.compNeedsConsecutiveRegisters = true;
retNode = gtNewSimdHWIntrinsicNode(retType, op1, op2, intrinsic, simdBaseJitType, simdSize);
break;
}

case NI_Vector64_Sum:
case NI_Vector128_Sum:
{
28 changes: 28 additions & 0 deletions src/coreclr/jit/hwintrinsiccodegenarm64.cpp
@@ -763,6 +763,34 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
GetEmitter()->emitIns_R_R_R(ins, emitTypeSize(intrin.baseType), op2Reg, op3Reg, op1Reg);
break;

case NI_AdvSimd_StoreVector64x2:
case NI_AdvSimd_Arm64_StoreVector128x2:
{
unsigned regCount = 0;

assert(intrin.op2->OperIsFieldList());

// The FIELD_LIST fields were allocated consecutive registers, so the first field's register names the whole two-register ST2 source group.
GenTreeFieldList* fieldList = intrin.op2->AsFieldList();
GenTree* firstField = fieldList->Uses().GetHead()->GetNode();
op2Reg = firstField->GetRegNum();

#ifdef DEBUG
regNumber argReg = op2Reg;
for (GenTreeFieldList::Use& use : fieldList->Uses())
{
regCount++;

GenTree* argNode = use.GetNode();
assert(argReg == argNode->GetRegNum());
argReg = REG_NEXT(argReg);
}
assert(regCount == 2);
#endif

// ST2 encodes only the first source register; the second is implicitly the next consecutive register (Vn+1).
GetEmitter()->emitIns_R_R(ins, emitSize, op2Reg, op1Reg, opt);
break;
}

case NI_Vector64_CreateScalarUnsafe:
case NI_Vector128_CreateScalarUnsafe:
if (intrin.op1->isContainedFltOrDblImmed())
2 changes: 2 additions & 0 deletions src/coreclr/jit/hwintrinsiclistarm64.h
@@ -467,6 +467,7 @@ HARDWARE_INTRINSIC(AdvSimd, SignExtendWideningUpper,
HARDWARE_INTRINSIC(AdvSimd, SqrtScalar, 8, 1, true, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_fsqrt, INS_fsqrt}, HW_Category_SIMD, HW_Flag_SIMDScalar)
HARDWARE_INTRINSIC(AdvSimd, Store, -1, 2, true, {INS_st1, INS_st1, INS_st1, INS_st1, INS_st1, INS_st1, INS_st1, INS_st1, INS_st1, INS_st1}, HW_Category_Helper, HW_Flag_SpecialImport|HW_Flag_BaseTypeFromSecondArg|HW_Flag_NoCodeGen)
HARDWARE_INTRINSIC(AdvSimd, StoreSelectedScalar, -1, 3, true, {INS_st1, INS_st1, INS_st1, INS_st1, INS_st1, INS_st1, INS_st1, INS_st1, INS_st1, INS_st1}, HW_Category_MemoryStore, HW_Flag_BaseTypeFromSecondArg|HW_Flag_HasImmediateOperand|HW_Flag_SIMDScalar|HW_Flag_SpecialCodeGen)
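// StoreVector64x2: ST2 has no 64-bit-element arrangement for 64-bit vectors, so the long/ulong/double entries below are INS_invalid.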
HARDWARE_INTRINSIC(AdvSimd, StoreVector64x2, 8, 2, true, {INS_st2, INS_st2, INS_st2, INS_st2, INS_st2, INS_st2, INS_invalid, INS_invalid, INS_st2, INS_invalid}, HW_Category_MemoryStore, HW_Flag_BaseTypeFromFirstArg|HW_Flag_SpecialImport|HW_Flag_SpecialCodeGen|HW_Flag_NeedsConsecutiveRegisters)
HARDWARE_INTRINSIC(AdvSimd, Subtract, -1, 2, true, {INS_sub, INS_sub, INS_sub, INS_sub, INS_sub, INS_sub, INS_sub, INS_sub, INS_fsub, INS_invalid}, HW_Category_SIMD, HW_Flag_NoFlag)
HARDWARE_INTRINSIC(AdvSimd, SubtractHighNarrowingLower, 8, 2, true, {INS_subhn, INS_subhn, INS_subhn, INS_subhn, INS_subhn, INS_subhn, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_NoFlag)
HARDWARE_INTRINSIC(AdvSimd, SubtractHighNarrowingUpper, 16, 3, true, {INS_subhn2, INS_subhn2, INS_subhn2, INS_subhn2, INS_subhn2, INS_subhn2, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_HasRMWSemantics)
@@ -645,6 +646,7 @@ HARDWARE_INTRINSIC(AdvSimd_Arm64, StorePair,
HARDWARE_INTRINSIC(AdvSimd_Arm64, StorePairScalar, 8, 3, true, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_stp, INS_stp, INS_invalid, INS_invalid, INS_stp, INS_invalid}, HW_Category_MemoryStore, HW_Flag_BaseTypeFromSecondArg|HW_Flag_SpecialCodeGen)
HARDWARE_INTRINSIC(AdvSimd_Arm64, StorePairScalarNonTemporal, 8, 3, true, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_stnp, INS_stnp, INS_invalid, INS_invalid, INS_stnp, INS_invalid}, HW_Category_MemoryStore, HW_Flag_BaseTypeFromSecondArg|HW_Flag_SpecialCodeGen)
HARDWARE_INTRINSIC(AdvSimd_Arm64, StorePairNonTemporal, -1, 3, true, {INS_stnp, INS_stnp, INS_stnp, INS_stnp, INS_stnp, INS_stnp, INS_stnp, INS_stnp, INS_stnp, INS_stp}, HW_Category_MemoryStore, HW_Flag_BaseTypeFromSecondArg|HW_Flag_SpecialCodeGen)
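// StoreVector128x2: every 128-bit arrangement holds at least two elements to interleave, so all ten base types map to INS_st2.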
HARDWARE_INTRINSIC(AdvSimd_Arm64, StoreVector128x2, 16, 2, true, {INS_st2, INS_st2, INS_st2, INS_st2, INS_st2, INS_st2, INS_st2, INS_st2, INS_st2, INS_st2}, HW_Category_MemoryStore, HW_Flag_BaseTypeFromFirstArg|HW_Flag_SpecialImport|HW_Flag_SpecialCodeGen|HW_Flag_NeedsConsecutiveRegisters)
HARDWARE_INTRINSIC(AdvSimd_Arm64, Subtract, 16, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_fsub}, HW_Category_SIMD, HW_Flag_NoFlag)
HARDWARE_INTRINSIC(AdvSimd_Arm64, SubtractSaturateScalar, 8, 2, true, {INS_sqsub, INS_uqsub, INS_sqsub, INS_uqsub, INS_sqsub, INS_uqsub, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_SIMDScalar)
HARDWARE_INTRINSIC(AdvSimd_Arm64, TransposeEven, -1, 2, true, {INS_trn1, INS_trn1, INS_trn1, INS_trn1, INS_trn1, INS_trn1, INS_trn1, INS_trn1, INS_trn1, INS_trn1}, HW_Category_SIMD, HW_Flag_NoFlag)
50 changes: 33 additions & 17 deletions src/coreclr/jit/lsraarm64.cpp
@@ -1548,25 +1548,41 @@ int LinearScan::BuildHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree, int* pDstCou

 else if (HWIntrinsicInfo::NeedsConsecutiveRegisters(intrin.id))
 {
-    if ((intrin.id == NI_AdvSimd_VectorTableLookup) || (intrin.id == NI_AdvSimd_Arm64_VectorTableLookup))
-    {
-        assert(intrin.op2 != nullptr);
-        srcCount += BuildOperandUses(intrin.op2);
-    }
-    else
+    switch (intrin.id)
     {
-        assert(intrin.op2 != nullptr);
-        assert(intrin.op3 != nullptr);
-        assert((intrin.id == NI_AdvSimd_VectorTableLookupExtension) ||
-               (intrin.id == NI_AdvSimd_Arm64_VectorTableLookupExtension));
-        assert(isRMW);
-        srcCount += BuildConsecutiveRegistersForUse(intrin.op2, intrin.op1);
-        srcCount += BuildDelayFreeUses(intrin.op3, intrin.op1);
+        case NI_AdvSimd_VectorTableLookup:
+        case NI_AdvSimd_Arm64_VectorTableLookup:
+            assert(intrin.op2 != nullptr);
+            srcCount += BuildOperandUses(intrin.op2);
+            assert(dstCount == 1);
+            buildInternalRegisterUses();
+            BuildDef(intrinsicTree);
+            *pDstCount = 1;
+            break;
+
+        case NI_AdvSimd_VectorTableLookupExtension:
+        case NI_AdvSimd_Arm64_VectorTableLookupExtension:
+            assert(intrin.op2 != nullptr);
+            assert(intrin.op3 != nullptr);
+            assert(isRMW);
+            srcCount += BuildConsecutiveRegistersForUse(intrin.op2, intrin.op1);
+            srcCount += BuildDelayFreeUses(intrin.op3, intrin.op1);
+            assert(dstCount == 1);
+            buildInternalRegisterUses();
+            BuildDef(intrinsicTree);
+            *pDstCount = 1;
+            break;
+        case NI_AdvSimd_StoreVector64x2:
+        case NI_AdvSimd_Arm64_StoreVector128x2:
+            assert(intrin.op1 != nullptr);
+            srcCount += BuildConsecutiveRegistersForUse(intrin.op2);
+            assert(dstCount == 0);
+            buildInternalRegisterUses();
+            *pDstCount = 0;
+            break;
+        default:
+            noway_assert(!"Not supported as a multiple consecutive register intrinsic");
     }
-    assert(dstCount == 1);
-    buildInternalRegisterUses();
-    BuildDef(intrinsicTree);
-    *pDstCount = 1;
     return srcCount;
 }
else if (intrin.op2 != nullptr)
Expand Down
src/libraries/System.Private.CoreLib/src/System/Runtime/Intrinsics/Arm/AdvSimd.PlatformNotSupported.cs
@@ -3176,6 +3176,56 @@ internal Arm64() { }
/// </summary>
public static unsafe void StorePairScalarNonTemporal(uint* address, Vector64<uint> value1, Vector64<uint> value2) { throw new PlatformNotSupportedException(); }

/// <summary>
/// A64: ST2 { Vn.16B, Vn+1.16B }, [Xn]
/// </summary>
public static unsafe void StoreVector128x2(byte* address, (Vector128<byte> Value1, Vector128<byte> Value2) value) { throw new PlatformNotSupportedException(); }

/// <summary>
/// A64: ST2 { Vn.16B, Vn+1.16B }, [Xn]
/// </summary>
public static unsafe void StoreVector128x2(sbyte* address, (Vector128<sbyte> Value1, Vector128<sbyte> Value2) value) { throw new PlatformNotSupportedException(); }

/// <summary>
/// A64: ST2 { Vn.8H, Vn+1.8H }, [Xn]
/// </summary>
public static unsafe void StoreVector128x2(short* address, (Vector128<short> Value1, Vector128<short> Value2) value) { throw new PlatformNotSupportedException(); }

/// <summary>
/// A64: ST2 { Vn.8H, Vn+1.8H }, [Xn]
/// </summary>
public static unsafe void StoreVector128x2(ushort* address, (Vector128<ushort> Value1, Vector128<ushort> Value2) value) { throw new PlatformNotSupportedException(); }

/// <summary>
/// A64: ST2 { Vn.4S, Vn+1.4S }, [Xn]
/// </summary>
public static unsafe void StoreVector128x2(int* address, (Vector128<int> Value1, Vector128<int> Value2) value) { throw new PlatformNotSupportedException(); }

/// <summary>
/// A64: ST2 { Vn.4S, Vn+1.4S }, [Xn]
/// </summary>
public static unsafe void StoreVector128x2(uint* address, (Vector128<uint> Value1, Vector128<uint> Value2) value) { throw new PlatformNotSupportedException(); }

/// <summary>
/// A64: ST2 { Vn.2D, Vn+1.2D }, [Xn]
/// </summary>
public static unsafe void StoreVector128x2(long* address, (Vector128<long> Value1, Vector128<long> Value2) value) { throw new PlatformNotSupportedException(); }

/// <summary>
/// A64: ST2 { Vn.2D, Vn+1.2D }, [Xn]
/// </summary>
public static unsafe void StoreVector128x2(ulong* address, (Vector128<ulong> Value1, Vector128<ulong> Value2) value) { throw new PlatformNotSupportedException(); }

/// <summary>
/// A64: ST2 { Vn.4S, Vn+1.4S }, [Xn]
/// </summary>
public static unsafe void StoreVector128x2(float* address, (Vector128<float> Value1, Vector128<float> Value2) value) { throw new PlatformNotSupportedException(); }

/// <summary>
/// A64: ST2 { Vn.2D, Vn+1.2D }, [Xn]
/// </summary>
public static unsafe void StoreVector128x2(double* address, (Vector128<double> Value1, Vector128<double> Value2) value) { throw new PlatformNotSupportedException(); }

/// <summary>
/// float64x2_t vsubq_f64 (float64x2_t a, float64x2_t b)
/// A64: FSUB Vd.2D, Vn.2D, Vm.2D
@@ -14437,6 +14487,41 @@ internal Arm64() { }
/// </summary>
public static unsafe void StoreSelectedScalar(ulong* address, Vector128<ulong> value, [ConstantExpected(Max = (byte)(1))] byte index) { throw new PlatformNotSupportedException(); }

/// <summary>
/// A64: ST2 { Vn.8B, Vn+1.8B }, [Xn]
/// </summary>
public static unsafe void StoreVector64x2(byte* address, (Vector64<byte> Value1, Vector64<byte> Value2) value) { throw new PlatformNotSupportedException(); }

/// <summary>
/// A64: ST2 { Vn.8B, Vn+1.8B }, [Xn]
/// </summary>
public static unsafe void StoreVector64x2(sbyte* address, (Vector64<sbyte> Value1, Vector64<sbyte> Value2) value) { throw new PlatformNotSupportedException(); }

/// <summary>
/// A64: ST2 { Vn.4H, Vn+1.4H }, [Xn]
/// </summary>
public static unsafe void StoreVector64x2(short* address, (Vector64<short> Value1, Vector64<short> Value2) value) { throw new PlatformNotSupportedException(); }

/// <summary>
/// A64: ST2 { Vn.4H, Vn+1.4H }, [Xn]
/// </summary>
public static unsafe void StoreVector64x2(ushort* address, (Vector64<ushort> Value1, Vector64<ushort> Value2) value) { throw new PlatformNotSupportedException(); }

/// <summary>
/// A64: ST2 { Vn.2S, Vn+1.2S }, [Xn]
/// </summary>
public static unsafe void StoreVector64x2(int* address, (Vector64<int> Value1, Vector64<int> Value2) value) { throw new PlatformNotSupportedException(); }

/// <summary>
/// A64: ST2 { Vn.2S, Vn+1.2S }, [Xn]
/// </summary>
public static unsafe void StoreVector64x2(uint* address, (Vector64<uint> Value1, Vector64<uint> Value2) value) { throw new PlatformNotSupportedException(); }

/// <summary>
/// A64: ST2 { Vn.2S, Vn+1.2S }, [Xn]
/// </summary>
public static unsafe void StoreVector64x2(float* address, (Vector64<float> Value1, Vector64<float> Value2) value) { throw new PlatformNotSupportedException(); }

/// <summary>
/// uint8x8_t vsub_u8 (uint8x8_t a, uint8x8_t b)
/// A32: VSUB.I8 Dd, Dn, Dm
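The 64-bit variant has the same shape over Vector64 values. A sketch of a typical use, writing separate real and imaginary planes out as interleaved complex pairs (the helper below is hypothetical, not from this commit):

    using System.Runtime.Intrinsics;
    using System.Runtime.Intrinsics.Arm;

    internal static class St2FloatSketch
    {
        // Writes real[0], imag[0], real[1], imag[1] to 'interleaved'.
        public static unsafe void StoreComplexPair(float* interleaved, Vector64<float> real, Vector64<float> imag)
        {
            if (AdvSimd.IsSupported)
            {
                // Emits: ST2 { Vn.2S, Vn+1.2S }, [Xn]
                AdvSimd.StoreVector64x2(interleaved, (real, imag));
            }
        }
    }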