JIT ARM64-SVE: Add CreateWhileLessThan* #100949

Merged: 8 commits, Apr 25, 2024
Changes from 7 commits
11 changes: 11 additions & 0 deletions src/coreclr/jit/hwintrinsic.cpp
@@ -1515,6 +1515,17 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
}
break;

case NI_Sve_CreateWhileLessThanMask8Bit:
Contributor Author


I don't like how we have these target-specific switch statements in hwintrinsic.cpp; it makes the code hard to follow. I also expect the number of these special cases to grow.

I'd much prefer it if each of these intrinsics (including all the NEON and x86 ones) were marked with SpecialImport. The special-import cases would then have to duplicate the common get-args code, but it's only a few lines (which could go into a helper). Not going to do it for this PR.

Member


👍, we've definitely done that for almost all the x86/x64 code; there are only a couple of special handlers left for that platform (like Sse42.Crc32). It would be nice to get both platforms handling this consistently.

case NI_Sve_CreateWhileLessThanOrEqualMask8Bit:
case NI_Sve_CreateWhileLessThanMask16Bit:
case NI_Sve_CreateWhileLessThanOrEqualMask16Bit:
case NI_Sve_CreateWhileLessThanMask32Bit:
case NI_Sve_CreateWhileLessThanOrEqualMask32Bit:
case NI_Sve_CreateWhileLessThanMask64Bit:
case NI_Sve_CreateWhileLessThanOrEqualMask64Bit:
retNode->AsHWIntrinsic()->SetAuxiliaryJitType(sigReader.op1JitType);
break;

default:
break;
}
34 changes: 34 additions & 0 deletions src/coreclr/jit/hwintrinsiccodegenarm64.cpp
@@ -1300,6 +1300,40 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
GetEmitter()->emitIns_R_PATTERN(ins, emitSize, targetReg, opt, SVE_PATTERN_ALL);
break;

case NI_Sve_CreateWhileLessThanMask8Bit:
case NI_Sve_CreateWhileLessThanMask16Bit:
case NI_Sve_CreateWhileLessThanMask32Bit:
case NI_Sve_CreateWhileLessThanMask64Bit:
{
// Emit size and instruction are based on the scalar operands.
var_types auxType = node->GetAuxiliaryType();
emitSize = emitActualTypeSize(auxType);
if (varTypeIsUnsigned(auxType))
{
ins = INS_sve_whilelo;
}

GetEmitter()->emitIns_R_R_R(ins, emitSize, targetReg, op1Reg, op2Reg, opt);
break;
}

case NI_Sve_CreateWhileLessThanOrEqualMask8Bit:
case NI_Sve_CreateWhileLessThanOrEqualMask16Bit:
case NI_Sve_CreateWhileLessThanOrEqualMask32Bit:
case NI_Sve_CreateWhileLessThanOrEqualMask64Bit:
{
// Emit size and instruction are based on the scalar operands.
var_types auxType = node->GetAuxiliaryType();
emitSize = emitActualTypeSize(auxType);
if (varTypeIsUnsigned(auxType))
{
ins = INS_sve_whilels;
}

GetEmitter()->emitIns_R_R_R(ins, emitSize, targetReg, op1Reg, op2Reg, opt);
break;
}

default:
unreached();
}
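For illustration only (a sketch inferred from the instruction mappings documented in the API comments below, not assembly captured from this PR), the managed overload chosen determines both the instruction and the register width:

// Signed operands keep the table default (whilelt / whilele); unsigned operands
// are switched to whilelo / whilels in the codegen above via the auxiliary type.
Sve.CreateWhileLessThanMask16Bit(0, 10);            // whilelt Presult.H, Wop1, Wop2
Sve.CreateWhileLessThanMask16Bit(0u, 10u);          // whilelo Presult.H, Wop1, Wop2
Sve.CreateWhileLessThanOrEqualMask16Bit(0L, 9L);    // whilele Presult.H, Xop1, Xop2
Sve.CreateWhileLessThanOrEqualMask16Bit(0UL, 9UL);  // whilels Presult.H, Xop1, Xop2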
9 changes: 9 additions & 0 deletions src/coreclr/jit/hwintrinsiclistarm64sve.h
@@ -28,6 +28,15 @@ HARDWARE_INTRINSIC(Sve, CreateTrueMaskUInt16,
HARDWARE_INTRINSIC(Sve, CreateTrueMaskUInt32, -1, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_ptrue, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_EnumPattern, HW_Flag_Scalable|HW_Flag_HasImmediateOperand|HW_Flag_ReturnsPerElementMask)
HARDWARE_INTRINSIC(Sve, CreateTrueMaskUInt64, -1, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_ptrue, INS_invalid, INS_invalid}, HW_Category_EnumPattern, HW_Flag_Scalable|HW_Flag_HasImmediateOperand|HW_Flag_ReturnsPerElementMask)

HARDWARE_INTRINSIC(Sve, CreateWhileLessThanMask16Bit, -1, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_sve_whilelt, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_SpecialCodeGen|HW_Flag_ReturnsPerElementMask)
HARDWARE_INTRINSIC(Sve, CreateWhileLessThanMask32Bit, -1, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_whilelt, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_SpecialCodeGen|HW_Flag_ReturnsPerElementMask)
HARDWARE_INTRINSIC(Sve, CreateWhileLessThanMask64Bit, -1, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_whilelt, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_SpecialCodeGen|HW_Flag_ReturnsPerElementMask)
HARDWARE_INTRINSIC(Sve, CreateWhileLessThanMask8Bit, -1, 2, false, {INS_invalid, INS_sve_whilelt, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_SpecialCodeGen|HW_Flag_ReturnsPerElementMask)
HARDWARE_INTRINSIC(Sve, CreateWhileLessThanOrEqualMask16Bit, -1, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_sve_whilele, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_SpecialCodeGen|HW_Flag_ReturnsPerElementMask)
HARDWARE_INTRINSIC(Sve, CreateWhileLessThanOrEqualMask32Bit, -1, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_whilele, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_SpecialCodeGen|HW_Flag_ReturnsPerElementMask)
HARDWARE_INTRINSIC(Sve, CreateWhileLessThanOrEqualMask64Bit, -1, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_whilele, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_SpecialCodeGen|HW_Flag_ReturnsPerElementMask)
HARDWARE_INTRINSIC(Sve, CreateWhileLessThanOrEqualMask8Bit, -1, 2, false, {INS_invalid, INS_sve_whilele, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_SpecialCodeGen|HW_Flag_ReturnsPerElementMask)

HARDWARE_INTRINSIC(Sve, LoadVector, -1, 2, true, {INS_sve_ld1b, INS_sve_ld1b, INS_sve_ld1h, INS_sve_ld1h, INS_sve_ld1w, INS_sve_ld1w, INS_sve_ld1d, INS_sve_ld1d, INS_sve_ld1w, INS_sve_ld1d}, HW_Category_MemoryLoad, HW_Flag_Scalable|HW_Flag_LowMaskedOperation)


@@ -121,6 +121,221 @@ internal Arm64() { }
public static unsafe Vector<ulong> CreateTrueMaskUInt64([ConstantExpected] SveMaskPattern pattern = SveMaskPattern.All) { throw new PlatformNotSupportedException(); }


/// CreateWhileLessThanMask16Bit : While incrementing scalar is less than

/// <summary>
/// svbool_t svwhilelt_b16[_s32](int32_t op1, int32_t op2)
/// WHILELT Presult.H, Wop1, Wop2
/// </summary>
public static unsafe Vector<ushort> CreateWhileLessThanMask16Bit(int left, int right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilelt_b16[_s64](int64_t op1, int64_t op2)
/// WHILELT Presult.H, Xop1, Xop2
/// </summary>
public static unsafe Vector<ushort> CreateWhileLessThanMask16Bit(long left, long right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilelt_b16[_u32](uint32_t op1, uint32_t op2)
/// WHILELO Presult.H, Wop1, Wop2
/// </summary>
public static unsafe Vector<ushort> CreateWhileLessThanMask16Bit(uint left, uint right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilelt_b16[_u64](uint64_t op1, uint64_t op2)
/// WHILELO Presult.H, Xop1, Xop2
/// </summary>
public static unsafe Vector<ushort> CreateWhileLessThanMask16Bit(ulong left, ulong right) { throw new PlatformNotSupportedException(); }


/// CreateWhileLessThanMask32Bit : While incrementing scalar is less than

/// <summary>
/// svbool_t svwhilelt_b32[_s32](int32_t op1, int32_t op2)
/// WHILELT Presult.S, Wop1, Wop2
/// </summary>
public static unsafe Vector<uint> CreateWhileLessThanMask32Bit(int left, int right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilelt_b32[_s64](int64_t op1, int64_t op2)
/// WHILELT Presult.S, Xop1, Xop2
/// </summary>
public static unsafe Vector<uint> CreateWhileLessThanMask32Bit(long left, long right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilelt_b32[_u32](uint32_t op1, uint32_t op2)
/// WHILELO Presult.S, Wop1, Wop2
/// </summary>
public static unsafe Vector<uint> CreateWhileLessThanMask32Bit(uint left, uint right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilelt_b32[_u64](uint64_t op1, uint64_t op2)
/// WHILELO Presult.S, Xop1, Xop2
/// </summary>
public static unsafe Vector<uint> CreateWhileLessThanMask32Bit(ulong left, ulong right) { throw new PlatformNotSupportedException(); }


/// CreateWhileLessThanMask64Bit : While incrementing scalar is less than

/// <summary>
/// svbool_t svwhilelt_b64[_s32](int32_t op1, int32_t op2)
/// WHILELT Presult.D, Wop1, Wop2
/// </summary>
public static unsafe Vector<ulong> CreateWhileLessThanMask64Bit(int left, int right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilelt_b64[_s64](int64_t op1, int64_t op2)
/// WHILELT Presult.D, Xop1, Xop2
/// </summary>
public static unsafe Vector<ulong> CreateWhileLessThanMask64Bit(long left, long right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilelt_b64[_u32](uint32_t op1, uint32_t op2)
/// WHILELO Presult.D, Wop1, Wop2
/// </summary>
public static unsafe Vector<ulong> CreateWhileLessThanMask64Bit(uint left, uint right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilelt_b64[_u64](uint64_t op1, uint64_t op2)
/// WHILELO Presult.D, Xop1, Xop2
/// </summary>
public static unsafe Vector<ulong> CreateWhileLessThanMask64Bit(ulong left, ulong right) { throw new PlatformNotSupportedException(); }


/// CreateWhileLessThanMask8Bit : While incrementing scalar is less than

/// <summary>
/// svbool_t svwhilelt_b8[_s32](int32_t op1, int32_t op2)
/// WHILELT Presult.B, Wop1, Wop2
/// </summary>
public static unsafe Vector<byte> CreateWhileLessThanMask8Bit(int left, int right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilelt_b8[_s64](int64_t op1, int64_t op2)
/// WHILELT Presult.B, Xop1, Xop2
/// </summary>
public static unsafe Vector<byte> CreateWhileLessThanMask8Bit(long left, long right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilelt_b8[_u32](uint32_t op1, uint32_t op2)
/// WHILELO Presult.B, Wop1, Wop2
/// </summary>
public static unsafe Vector<byte> CreateWhileLessThanMask8Bit(uint left, uint right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilelt_b8[_u64](uint64_t op1, uint64_t op2)
/// WHILELO Presult.B, Xop1, Xop2
/// </summary>
public static unsafe Vector<byte> CreateWhileLessThanMask8Bit(ulong left, ulong right) { throw new PlatformNotSupportedException(); }


/// CreateWhileLessThanOrEqualMask16Bit : While incrementing scalar is less than or equal to

/// <summary>
/// svbool_t svwhilele_b16[_s32](int32_t op1, int32_t op2)
/// WHILELE Presult.H, Wop1, Wop2
/// </summary>
public static unsafe Vector<ushort> CreateWhileLessThanOrEqualMask16Bit(int left, int right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilele_b16[_s64](int64_t op1, int64_t op2)
/// WHILELE Presult.H, Xop1, Xop2
/// </summary>
public static unsafe Vector<ushort> CreateWhileLessThanOrEqualMask16Bit(long left, long right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilele_b16[_u32](uint32_t op1, uint32_t op2)
/// WHILELS Presult.H, Wop1, Wop2
/// </summary>
public static unsafe Vector<ushort> CreateWhileLessThanOrEqualMask16Bit(uint left, uint right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilele_b16[_u64](uint64_t op1, uint64_t op2)
/// WHILELS Presult.H, Xop1, Xop2
/// </summary>
public static unsafe Vector<ushort> CreateWhileLessThanOrEqualMask16Bit(ulong left, ulong right) { throw new PlatformNotSupportedException(); }


/// CreateWhileLessThanOrEqualMask32Bit : While incrementing scalar is less than or equal to

/// <summary>
/// svbool_t svwhilele_b32[_s32](int32_t op1, int32_t op2)
/// WHILELE Presult.S, Wop1, Wop2
/// </summary>
public static unsafe Vector<uint> CreateWhileLessThanOrEqualMask32Bit(int left, int right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilele_b32[_s64](int64_t op1, int64_t op2)
/// WHILELE Presult.S, Xop1, Xop2
/// </summary>
public static unsafe Vector<uint> CreateWhileLessThanOrEqualMask32Bit(long left, long right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilele_b32[_u32](uint32_t op1, uint32_t op2)
/// WHILELS Presult.S, Wop1, Wop2
/// </summary>
public static unsafe Vector<uint> CreateWhileLessThanOrEqualMask32Bit(uint left, uint right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilele_b32[_u64](uint64_t op1, uint64_t op2)
/// WHILELS Presult.S, Xop1, Xop2
/// </summary>
public static unsafe Vector<uint> CreateWhileLessThanOrEqualMask32Bit(ulong left, ulong right) { throw new PlatformNotSupportedException(); }


/// CreateWhileLessThanOrEqualMask64Bit : While incrementing scalar is less than or equal to

/// <summary>
/// svbool_t svwhilele_b64[_s32](int32_t op1, int32_t op2)
/// WHILELE Presult.D, Wop1, Wop2
/// </summary>
public static unsafe Vector<ulong> CreateWhileLessThanOrEqualMask64Bit(int left, int right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilele_b64[_s64](int64_t op1, int64_t op2)
/// WHILELE Presult.D, Xop1, Xop2
/// </summary>
public static unsafe Vector<ulong> CreateWhileLessThanOrEqualMask64Bit(long left, long right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilele_b64[_u32](uint32_t op1, uint32_t op2)
/// WHILELS Presult.D, Wop1, Wop2
/// </summary>
public static unsafe Vector<ulong> CreateWhileLessThanOrEqualMask64Bit(uint left, uint right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilele_b64[_u64](uint64_t op1, uint64_t op2)
/// WHILELS Presult.D, Xop1, Xop2
/// </summary>
public static unsafe Vector<ulong> CreateWhileLessThanOrEqualMask64Bit(ulong left, ulong right) { throw new PlatformNotSupportedException(); }


/// CreateWhileLessThanOrEqualMask8Bit : While incrementing scalar is less than or equal to

/// <summary>
/// svbool_t svwhilele_b8[_s32](int32_t op1, int32_t op2)
/// WHILELE Presult.B, Wop1, Wop2
/// </summary>
public static unsafe Vector<byte> CreateWhileLessThanOrEqualMask8Bit(int left, int right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilele_b8[_s64](int64_t op1, int64_t op2)
/// WHILELE Presult.B, Xop1, Xop2
/// </summary>
public static unsafe Vector<byte> CreateWhileLessThanOrEqualMask8Bit(long left, long right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilele_b8[_u32](uint32_t op1, uint32_t op2)
/// WHILELS Presult.B, Wop1, Wop2
/// </summary>
public static unsafe Vector<byte> CreateWhileLessThanOrEqualMask8Bit(uint left, uint right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilele_b8[_u64](uint64_t op1, uint64_t op2)
/// WHILELS Presult.B, Xop1, Xop2
/// </summary>
public static unsafe Vector<byte> CreateWhileLessThanOrEqualMask8Bit(ulong left, ulong right) { throw new PlatformNotSupportedException(); }


/// LoadVector : Unextended load

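As a rough usage sketch of the new surface (hypothetical helper, not part of this PR; assumes Sve.IsSupported hardware, the masked LoadVector overload declared above, and that Vector<uint>.Count matches the SVE vector length):

using System.Numerics;
using System.Runtime.Intrinsics.Arm;

internal static class SveWhileLessThanDemo
{
    // Hypothetical example: sums a uint buffer, using the while-less-than predicate
    // so the final partial vector needs no scalar tail loop.
    internal static unsafe uint SumWithTailPredication(uint* data, int length)
    {
        uint total = 0;
        for (int i = 0; i < length; i += Vector<uint>.Count)
        {
            // Active lanes are those where (i + lane) < length.
            Vector<uint> mask = Sve.CreateWhileLessThanMask32Bit(i, length);
            // A predicated SVE load zeroes the inactive lanes, so summing the whole
            // vector is safe even for the last partial iteration.
            Vector<uint> chunk = Sve.LoadVector(mask, data + i);
            total += Vector.Sum(chunk);
        }
        return total;
    }
}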