Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Sve.StoreNarrowing() #102605

Merged
merged 3 commits into from
May 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/coreclr/jit/gentree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26711,6 +26711,7 @@ bool GenTreeHWIntrinsic::OperIsMemoryStore(GenTree** pAddr) const
case NI_Sve_StoreAndZipx2:
case NI_Sve_StoreAndZipx3:
case NI_Sve_StoreAndZipx4:
case NI_Sve_StoreNarrowing:
// For these SVE stores the destination address is taken from the second operand.
addr = Op(2);
break;
#endif // TARGET_ARM64
Expand Down
28 changes: 28 additions & 0 deletions src/coreclr/jit/hwintrinsicarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2479,6 +2479,34 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
break;
}

case NI_Sve_StoreNarrowing:
{
    // Sve.StoreNarrowing(mask, address, data) is a void store with exactly three operands.
    assert(sig->numArgs == 3);
    assert(retType == TYP_VOID);

    // Advance from the first argument to the second one (the destination pointer).
    CORINFO_ARG_LIST_HANDLE arg      = info.compCompHnd->getArgNext(sig->args);
    CORINFO_CLASS_HANDLE    argClass = info.compCompHnd->getArgClass(sig, arg);
    CORINFO_CLASS_HANDLE    tmpClass = NO_CLASS_HANDLE;

    // The size of narrowed target elements is determined from the second argument of StoreNarrowing().
    // Thus, we first extract the datatype of a pointer passed in the second argument and then store it as the
    // auxiliary type of intrinsic. This auxiliary type is then used in the codegen to choose the correct
    // instruction to emit.
    //
    // The second argument is a pointer rather than a SIMD type, so its type is queried directly from the
    // signature; no SIMD base-type lookup is done on it.
    CorInfoType ptrType = strip(info.compCompHnd->getArgType(sig, arg, &tmpClass));
    assert(ptrType == CORINFO_TYPE_PTR);

    // Replace the pointer type by its pointee type; that is the element type written to memory.
    ptrType = info.compCompHnd->getChildType(argClass, &tmpClass);

    // A narrowing store must write elements that are smaller than the vector's base type.
    assert(ptrType < simdBaseJitType);

    // Operands are popped in reverse order of the managed signature: data, address, mask.
    op3 = impPopStack().val;
    op2 = impPopStack().val;
    op1 = impPopStack().val;

    retNode = gtNewSimdHWIntrinsicNode(retType, op1, op2, op3, intrinsic, simdBaseJitType, simdSize);
    retNode->AsHWIntrinsic()->SetAuxiliaryJitType(ptrType);
    break;
}

default:
{
return nullptr;
Expand Down
9 changes: 9 additions & 0 deletions src/coreclr/jit/hwintrinsiccodegenarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,10 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
ins = varTypeIsUnsigned(intrin.baseType) ? INS_umsubl : INS_smsubl;
break;

case NI_Sve_StoreNarrowing:
// Unlike the default case, the instruction is looked up by the node's auxiliary type
// instead of the vector base type.
ins = HWIntrinsicInfo::lookupIns(intrin.id, node->GetAuxiliaryType());
break;

default:
ins = HWIntrinsicInfo::lookupIns(intrin.id, intrin.baseType);
break;
Expand Down Expand Up @@ -1773,6 +1777,11 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
break;
}

case NI_Sve_StoreNarrowing:
// The SVE arrangement option is derived from the size of the vector's base type
// (the un-narrowed source elements), not from the auxiliary type used to pick `ins`.
opt = emitter::optGetSveInsOpt(emitTypeSize(intrin.baseType));
// Registers are passed as op3, op1, op2 with an immediate of 0.
// NOTE(review): presumably data, mask, address and the "#0, MUL VL" offset of the
// ST1B/ST1H/ST1W forms - confirm against the emitter's operand order.
GetEmitter()->emitIns_R_R_R_I(ins, emitSize, op3Reg, op1Reg, op2Reg, 0, opt);
break;

case NI_Sve_UnzipEven:
case NI_Sve_UnzipOdd:
case NI_Sve_ZipHigh:
Expand Down
1 change: 1 addition & 0 deletions src/coreclr/jit/hwintrinsiclistarm64sve.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ HARDWARE_INTRINSIC(Sve, SignExtend8,
HARDWARE_INTRINSIC(Sve, SignExtendWideningLower, -1, 1, true, {INS_sve_sunpklo, INS_invalid, INS_sve_sunpklo, INS_invalid, INS_sve_sunpklo, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg)
HARDWARE_INTRINSIC(Sve, SignExtendWideningUpper, -1, 1, true, {INS_sve_sunpkhi, INS_invalid, INS_sve_sunpkhi, INS_invalid, INS_sve_sunpkhi, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg)
HARDWARE_INTRINSIC(Sve, StoreAndZip, -1, 3, true, {INS_sve_st1b, INS_sve_st1b, INS_sve_st1h, INS_sve_st1h, INS_sve_st1w, INS_sve_st1w, INS_sve_st1d, INS_sve_st1d, INS_sve_st1w, INS_sve_st1d}, HW_Category_MemoryStore, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_ExplicitMaskedOperation|HW_Flag_SpecialImport|HW_Flag_SpecialCodeGen|HW_Flag_LowMaskedOperation)
// StoreNarrowing: codegen looks this row up with the node's auxiliary type rather than the vector base type.
// The last two slots are INS_invalid (presumably the float/double columns - confirm against the table header).
HARDWARE_INTRINSIC(Sve, StoreNarrowing, -1, 3, true, {INS_sve_st1b, INS_sve_st1b, INS_sve_st1h, INS_sve_st1h, INS_sve_st1w, INS_sve_st1w, INS_sve_st1d, INS_sve_st1d, INS_invalid, INS_invalid}, HW_Category_MemoryStore, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_ExplicitMaskedOperation|HW_Flag_SpecialImport|HW_Flag_SpecialCodeGen|HW_Flag_LowMaskedOperation)
HARDWARE_INTRINSIC(Sve, Subtract, -1, 2, true, {INS_sve_sub, INS_sve_sub, INS_sve_sub, INS_sve_sub, INS_sve_sub, INS_sve_sub, INS_sve_sub, INS_sve_sub, INS_sve_fsub, INS_sve_fsub}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_OptionalEmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation)
HARDWARE_INTRINSIC(Sve, SubtractSaturate, -1, 2, true, {INS_sve_sqsub, INS_sve_uqsub, INS_sve_sqsub, INS_sve_uqsub, INS_sve_sqsub, INS_sve_uqsub, INS_sve_sqsub, INS_sve_uqsub, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_OptionalEmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation)
HARDWARE_INTRINSIC(Sve, UnzipEven, -1, 2, true, {INS_sve_uzp1, INS_sve_uzp1, INS_sve_uzp1, INS_sve_uzp1, INS_sve_uzp1, INS_sve_uzp1, INS_sve_uzp1, INS_sve_uzp1, INS_sve_uzp1, INS_sve_uzp1}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_SpecialCodeGen)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2998,6 +2998,80 @@ internal Arm64() { }
/// ST4D {Zdata0.D - Zdata3.D}, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreAndZip(Vector<ulong> mask, ulong* address, (Vector<ulong> Value1, Vector<ulong> Value2, Vector<ulong> Value3, Vector<ulong> Value4) data) { throw new PlatformNotSupportedException(); }
/// StoreNarrowing : Truncate to 8, 16, or 32 bits and store

/// <summary>
/// Truncates each 16-bit element of <paramref name="data"/> to 8 bits and stores it to <paramref name="address"/>, governed by <paramref name="mask"/>.
/// void svst1b[_s16](svbool_t pg, int8_t *base, svint16_t data)
/// ST1B Zdata.H, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<short> mask, sbyte* address, Vector<short> data) { throw new PlatformNotSupportedException(); }


/// <summary>
/// Truncates each 32-bit element of <paramref name="data"/> to 8 bits and stores it to <paramref name="address"/>, governed by <paramref name="mask"/>.
/// void svst1b[_s32](svbool_t pg, int8_t *base, svint32_t data)
/// ST1B Zdata.S, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<int> mask, sbyte* address, Vector<int> data) { throw new PlatformNotSupportedException(); }

/// <summary>
/// Truncates each 32-bit element of <paramref name="data"/> to 16 bits and stores it to <paramref name="address"/>, governed by <paramref name="mask"/>.
/// void svst1h[_s32](svbool_t pg, int16_t *base, svint32_t data)
/// ST1H Zdata.S, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<int> mask, short* address, Vector<int> data) { throw new PlatformNotSupportedException(); }

/// <summary>
/// Truncates each 64-bit element of <paramref name="data"/> to 8 bits and stores it to <paramref name="address"/>, governed by <paramref name="mask"/>.
/// void svst1b[_s64](svbool_t pg, int8_t *base, svint64_t data)
/// ST1B Zdata.D, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<long> mask, sbyte* address, Vector<long> data) { throw new PlatformNotSupportedException(); }

/// <summary>
/// Truncates each 64-bit element of <paramref name="data"/> to 16 bits and stores it to <paramref name="address"/>, governed by <paramref name="mask"/>.
/// void svst1h[_s64](svbool_t pg, int16_t *base, svint64_t data)
/// ST1H Zdata.D, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<long> mask, short* address, Vector<long> data) { throw new PlatformNotSupportedException(); }

/// <summary>
/// Truncates each 64-bit element of <paramref name="data"/> to 32 bits and stores it to <paramref name="address"/>, governed by <paramref name="mask"/>.
/// void svst1w[_s64](svbool_t pg, int32_t *base, svint64_t data)
/// ST1W Zdata.D, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<long> mask, int* address, Vector<long> data) { throw new PlatformNotSupportedException(); }

/// <summary>
/// Truncates each 16-bit element of <paramref name="data"/> to 8 bits and stores it to <paramref name="address"/>, governed by <paramref name="mask"/>.
/// void svst1b[_u16](svbool_t pg, uint8_t *base, svuint16_t data)
/// ST1B Zdata.H, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<ushort> mask, byte* address, Vector<ushort> data) { throw new PlatformNotSupportedException(); }

/// <summary>
/// Truncates each 32-bit element of <paramref name="data"/> to 8 bits and stores it to <paramref name="address"/>, governed by <paramref name="mask"/>.
/// void svst1b[_u32](svbool_t pg, uint8_t *base, svuint32_t data)
/// ST1B Zdata.S, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<uint> mask, byte* address, Vector<uint> data) { throw new PlatformNotSupportedException(); }

/// <summary>
/// Truncates each 32-bit element of <paramref name="data"/> to 16 bits and stores it to <paramref name="address"/>, governed by <paramref name="mask"/>.
/// void svst1h[_u32](svbool_t pg, uint16_t *base, svuint32_t data)
/// ST1H Zdata.S, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<uint> mask, ushort* address, Vector<uint> data) { throw new PlatformNotSupportedException(); }

/// <summary>
/// Truncates each 64-bit element of <paramref name="data"/> to 8 bits and stores it to <paramref name="address"/>, governed by <paramref name="mask"/>.
/// void svst1b[_u64](svbool_t pg, uint8_t *base, svuint64_t data)
/// ST1B Zdata.D, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<ulong> mask, byte* address, Vector<ulong> data) { throw new PlatformNotSupportedException(); }

/// <summary>
/// Truncates each 64-bit element of <paramref name="data"/> to 16 bits and stores it to <paramref name="address"/>, governed by <paramref name="mask"/>.
/// void svst1h[_u64](svbool_t pg, uint16_t *base, svuint64_t data)
/// ST1H Zdata.D, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<ulong> mask, ushort* address, Vector<ulong> data) { throw new PlatformNotSupportedException(); }

/// <summary>
/// Truncates each 64-bit element of <paramref name="data"/> to 32 bits and stores it to <paramref name="address"/>, governed by <paramref name="mask"/>.
/// void svst1w[_u64](svbool_t pg, uint32_t *base, svuint64_t data)
/// ST1W Zdata.D, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<ulong> mask, uint* address, Vector<ulong> data) { throw new PlatformNotSupportedException(); }


/// Subtract : Subtract
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3094,6 +3094,80 @@ internal Arm64() { }
/// ST4D {Zdata0.D - Zdata3.D}, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreAndZip(Vector<ulong> mask, ulong* address, (Vector<ulong> Value1, Vector<ulong> Value2, Vector<ulong> Value3, Vector<ulong> Value4) data) => StoreAndZip(mask, address, data);
/// StoreNarrowing : Truncate to 8, 16, or 32 bits and store


/// <summary>
/// Truncates each 16-bit element of <paramref name="data"/> to 8 bits and stores it to <paramref name="address"/>, governed by <paramref name="mask"/>.
/// void svst1b[_s16](svbool_t pg, int8_t *base, svint16_t data)
/// ST1B Zdata.H, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<short> mask, sbyte* address, Vector<short> data) => StoreNarrowing(mask, address, data);

/// <summary>
/// Truncates each 32-bit element of <paramref name="data"/> to 8 bits and stores it to <paramref name="address"/>, governed by <paramref name="mask"/>.
/// void svst1b[_s32](svbool_t pg, int8_t *base, svint32_t data)
/// ST1B Zdata.S, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<int> mask, sbyte* address, Vector<int> data) => StoreNarrowing(mask, address, data);

/// <summary>
/// Truncates each 32-bit element of <paramref name="data"/> to 16 bits and stores it to <paramref name="address"/>, governed by <paramref name="mask"/>.
/// void svst1h[_s32](svbool_t pg, int16_t *base, svint32_t data)
/// ST1H Zdata.S, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<int> mask, short* address, Vector<int> data) => StoreNarrowing(mask, address, data);

/// <summary>
/// Truncates each 64-bit element of <paramref name="data"/> to 8 bits and stores it to <paramref name="address"/>, governed by <paramref name="mask"/>.
/// void svst1b[_s64](svbool_t pg, int8_t *base, svint64_t data)
/// ST1B Zdata.D, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<long> mask, sbyte* address, Vector<long> data) => StoreNarrowing(mask, address, data);

/// <summary>
/// Truncates each 64-bit element of <paramref name="data"/> to 16 bits and stores it to <paramref name="address"/>, governed by <paramref name="mask"/>.
/// void svst1h[_s64](svbool_t pg, int16_t *base, svint64_t data)
/// ST1H Zdata.D, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<long> mask, short* address, Vector<long> data) => StoreNarrowing(mask, address, data);

/// <summary>
/// Truncates each 64-bit element of <paramref name="data"/> to 32 bits and stores it to <paramref name="address"/>, governed by <paramref name="mask"/>.
/// void svst1w[_s64](svbool_t pg, int32_t *base, svint64_t data)
/// ST1W Zdata.D, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<long> mask, int* address, Vector<long> data) => StoreNarrowing(mask, address, data);

/// <summary>
/// Truncates each 16-bit element of <paramref name="data"/> to 8 bits and stores it to <paramref name="address"/>, governed by <paramref name="mask"/>.
/// void svst1b[_u16](svbool_t pg, uint8_t *base, svuint16_t data)
/// ST1B Zdata.H, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<ushort> mask, byte* address, Vector<ushort> data) => StoreNarrowing(mask, address, data);

/// <summary>
/// Truncates each 32-bit element of <paramref name="data"/> to 8 bits and stores it to <paramref name="address"/>, governed by <paramref name="mask"/>.
/// void svst1b[_u32](svbool_t pg, uint8_t *base, svuint32_t data)
/// ST1B Zdata.S, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<uint> mask, byte* address, Vector<uint> data) => StoreNarrowing(mask, address, data);

/// <summary>
/// Truncates each 32-bit element of <paramref name="data"/> to 16 bits and stores it to <paramref name="address"/>, governed by <paramref name="mask"/>.
/// void svst1h[_u32](svbool_t pg, uint16_t *base, svuint32_t data)
/// ST1H Zdata.S, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<uint> mask, ushort* address, Vector<uint> data) => StoreNarrowing(mask, address, data);

/// <summary>
/// Truncates each 64-bit element of <paramref name="data"/> to 8 bits and stores it to <paramref name="address"/>, governed by <paramref name="mask"/>.
/// void svst1b[_u64](svbool_t pg, uint8_t *base, svuint64_t data)
/// ST1B Zdata.D, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<ulong> mask, byte* address, Vector<ulong> data) => StoreNarrowing(mask, address, data);

/// <summary>
/// Truncates each 64-bit element of <paramref name="data"/> to 16 bits and stores it to <paramref name="address"/>, governed by <paramref name="mask"/>.
/// void svst1h[_u64](svbool_t pg, uint16_t *base, svuint64_t data)
/// ST1H Zdata.D, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<ulong> mask, ushort* address, Vector<ulong> data) => StoreNarrowing(mask, address, data);

/// <summary>
/// Truncates each 64-bit element of <paramref name="data"/> to 32 bits and stores it to <paramref name="address"/>, governed by <paramref name="mask"/>.
/// void svst1w[_u64](svbool_t pg, uint32_t *base, svuint64_t data)
/// ST1W Zdata.D, Pg, [Xbase, #0, MUL VL]
/// </summary>
public static unsafe void StoreNarrowing(Vector<ulong> mask, uint* address, Vector<ulong> data) => StoreNarrowing(mask, address, data);


/// Subtract : Subtract
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4588,6 +4588,19 @@ internal Arm64() { }
public static unsafe void StoreAndZip(System.Numerics.Vector<ulong> mask, ulong* address, (System.Numerics.Vector<ulong> Value1, System.Numerics.Vector<ulong> Value2, System.Numerics.Vector<ulong> Value3) data) { throw null; }
public static unsafe void StoreAndZip(System.Numerics.Vector<ulong> mask, ulong* address, (System.Numerics.Vector<ulong> Value1, System.Numerics.Vector<ulong> Value2, System.Numerics.Vector<ulong> Value3, System.Numerics.Vector<ulong> Value4) data) { throw null; }

public static unsafe void StoreNarrowing(System.Numerics.Vector<short> mask, sbyte* address, System.Numerics.Vector<short> data) { throw null; }
public static unsafe void StoreNarrowing(System.Numerics.Vector<int> mask, sbyte* address, System.Numerics.Vector<int> data) { throw null; }
public static unsafe void StoreNarrowing(System.Numerics.Vector<int> mask, short* address, System.Numerics.Vector<int> data) { throw null; }
public static unsafe void StoreNarrowing(System.Numerics.Vector<long> mask, sbyte* address, System.Numerics.Vector<long> data) { throw null; }
public static unsafe void StoreNarrowing(System.Numerics.Vector<long> mask, short* address, System.Numerics.Vector<long> data) { throw null; }
public static unsafe void StoreNarrowing(System.Numerics.Vector<long> mask, int* address, System.Numerics.Vector<long> data) { throw null; }
public static unsafe void StoreNarrowing(System.Numerics.Vector<ushort> mask, byte* address, System.Numerics.Vector<ushort> data) { throw null; }
public static unsafe void StoreNarrowing(System.Numerics.Vector<uint> mask, byte* address, System.Numerics.Vector<uint> data) { throw null; }
public static unsafe void StoreNarrowing(System.Numerics.Vector<uint> mask, ushort* address, System.Numerics.Vector<uint> data) { throw null; }
public static unsafe void StoreNarrowing(System.Numerics.Vector<ulong> mask, byte* address, System.Numerics.Vector<ulong> data) { throw null; }
public static unsafe void StoreNarrowing(System.Numerics.Vector<ulong> mask, ushort* address, System.Numerics.Vector<ulong> data) { throw null; }
public static unsafe void StoreNarrowing(System.Numerics.Vector<ulong> mask, uint* address, System.Numerics.Vector<ulong> data) { throw null; }

public static System.Numerics.Vector<sbyte> Subtract(System.Numerics.Vector<sbyte> left, System.Numerics.Vector<sbyte> right) { throw null; }
public static System.Numerics.Vector<short> Subtract(System.Numerics.Vector<short> left, System.Numerics.Vector<short> right) { throw null; }
public static System.Numerics.Vector<int> Subtract(System.Numerics.Vector<int> left, System.Numerics.Vector<int> right) { throw null; }
Expand Down
Loading
Loading