Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JIT ARM64-SVE: Add AddAcross #101674

Merged
merged 5 commits into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions src/coreclr/jit/codegenarm64test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5314,11 +5314,11 @@ void CodeGen::genArm64EmitterUnitTestsSve()
#endif // ALL_ARM64_EMITTER_UNIT_TESTS_SVE_UNSUPPORTED

// IF_SVE_AI_3A
theEmitter->emitIns_R_R_R(INS_sve_saddv, EA_1BYTE, REG_V1, REG_P4, REG_V2,
theEmitter->emitIns_R_R_R(INS_sve_saddv, EA_SCALABLE, REG_V1, REG_P4, REG_V2,
Copy link
Contributor Author

@a74nh a74nh Apr 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All the codegen changes:

For these instructions, arg2 (EA_1BYTE etc) is never used as the return value is dependent on the input type which is already specified in opt.
Switching arg2 to EA_SCALABLE means there is no need to write special hwinstrinsiccodegen code.

I've changed the bare minimal of instructions needed to make this patch work. There are quite a few more reduction like instructions - we should do those as we get to them in the API

INS_OPTS_SCALABLE_B); // SADDV <Dd>, <Pg>, <Zn>.<T>
theEmitter->emitIns_R_R_R(INS_sve_saddv, EA_2BYTE, REG_V2, REG_P5, REG_V3,
theEmitter->emitIns_R_R_R(INS_sve_saddv, EA_SCALABLE, REG_V2, REG_P5, REG_V3,
INS_OPTS_SCALABLE_H); // SADDV <Dd>, <Pg>, <Zn>.<T>
theEmitter->emitIns_R_R_R(INS_sve_uaddv, EA_4BYTE, REG_V3, REG_P6, REG_V4,
theEmitter->emitIns_R_R_R(INS_sve_uaddv, EA_SCALABLE, REG_V3, REG_P6, REG_V4,
INS_OPTS_SCALABLE_S); // UADDV <Dd>, <Pg>, <Zn>.<T>

// IF_SVE_AJ_3A
Expand Down Expand Up @@ -6768,15 +6768,15 @@ void CodeGen::genArm64EmitterUnitTestsSve()
#endif // ALL_ARM64_EMITTER_UNIT_TESTS_SVE_UNSUPPORTED

// IF_SVE_HE_3A
theEmitter->emitIns_R_R_R(INS_sve_faddv, EA_2BYTE, REG_V21, REG_P7, REG_V7,
theEmitter->emitIns_R_R_R(INS_sve_faddv, EA_SCALABLE, REG_V21, REG_P7, REG_V7,
INS_OPTS_SCALABLE_H); // FADDV <V><d>, <Pg>, <Zn>.<T>
theEmitter->emitIns_R_R_R(INS_sve_fmaxnmv, EA_2BYTE, REG_V22, REG_P6, REG_V6,
theEmitter->emitIns_R_R_R(INS_sve_fmaxnmv, EA_SCALABLE, REG_V22, REG_P6, REG_V6,
INS_OPTS_SCALABLE_H); // FMAXNMV <V><d>, <Pg>, <Zn>.<T>
theEmitter->emitIns_R_R_R(INS_sve_fmaxv, EA_4BYTE, REG_V23, REG_P5, REG_V5,
theEmitter->emitIns_R_R_R(INS_sve_fmaxv, EA_SCALABLE, REG_V23, REG_P5, REG_V5,
INS_OPTS_SCALABLE_S); // FMAXV <V><d>, <Pg>, <Zn>.<T>
theEmitter->emitIns_R_R_R(INS_sve_fminnmv, EA_8BYTE, REG_V24, REG_P4, REG_V4,
theEmitter->emitIns_R_R_R(INS_sve_fminnmv, EA_SCALABLE, REG_V24, REG_P4, REG_V4,
INS_OPTS_SCALABLE_D); // FMINNMV <V><d>, <Pg>, <Zn>.<T>
theEmitter->emitIns_R_R_R(INS_sve_fminv, EA_4BYTE, REG_V25, REG_P3, REG_V3,
theEmitter->emitIns_R_R_R(INS_sve_fminv, EA_SCALABLE, REG_V25, REG_P3, REG_V3,
INS_OPTS_SCALABLE_S); // FMINV <V><d>, <Pg>, <Zn>.<T>

// IF_SVE_HQ_3A
Expand Down
29 changes: 23 additions & 6 deletions src/coreclr/jit/emitarm64sve.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3060,7 +3060,6 @@ void emitter::emitInsSve_R_R_R(instruction ins,
break;

case INS_sve_saddv:
case INS_sve_uaddv:
assert(isFloatReg(reg1));
assert(isLowPredicateRegister(reg2));
assert(isVectorRegister(reg3));
Expand All @@ -3069,6 +3068,15 @@ void emitter::emitInsSve_R_R_R(instruction ins,
fmt = IF_SVE_AI_3A;
break;

case INS_sve_uaddv:
assert(isFloatReg(reg1));
assert(isLowPredicateRegister(reg2));
assert(isVectorRegister(reg3));
assert(insOptsScalableStandard(opt));
assert(insScalableOptsNone(sopt));
fmt = IF_SVE_AI_3A;
break;

case INS_sve_addqv:
unreached(); // TODO-SVE: Not yet supported.
assert(isVectorRegister(reg1));
Expand Down Expand Up @@ -4059,7 +4067,7 @@ void emitter::emitInsSve_R_R_R(instruction ins,
assert(isLowPredicateRegister(reg2));
assert(isVectorRegister(reg3));
assert(insOptsScalableFloat(opt));
assert(isValidVectorElemsizeSveFloat(size));
assert(isScalableVectorSize(size));
assert(insScalableOptsNone(sopt));
fmt = IF_SVE_HE_3A;
break;
Expand All @@ -4069,7 +4077,7 @@ void emitter::emitInsSve_R_R_R(instruction ins,
assert(isLowPredicateRegister(reg2));
assert(isVectorRegister(reg3));
assert(insOptsScalableFloat(opt));
assert(isValidVectorElemsizeSveFloat(size));
assert(isScalableVectorSize(size));
assert(insScalableOptsNone(sopt));
fmt = IF_SVE_HJ_3A;
break;
Expand Down Expand Up @@ -12618,7 +12626,7 @@ void emitter::emitInsSveSanityCheck(instrDesc* id)
assert(isVectorRegister(id->idReg1())); // ddddd
assert(isLowPredicateRegister(id->idReg2())); // ggg
assert(isVectorRegister(id->idReg3())); // mmmmm
assert(isValidVectorElemsizeSveFloat(id->idOpSize()));
assert(isScalableVectorSize(id->idOpSize()));
break;

// Scalable to general register.
Expand Down Expand Up @@ -13211,11 +13219,20 @@ void emitter::emitInsSveSanityCheck(instrDesc* id)

// Scalable, widening to scalar SIMD.
case IF_SVE_AI_3A: // ........xx...... ...gggnnnnnddddd -- SVE integer add reduction (predicated)
assert(insOptsScalableWide(id->idInsOpt())); // xx
switch (id->idIns())
{
case INS_sve_saddv:
assert(insOptsScalableWide(id->idInsOpt())); // xx
break;

default:
assert(insOptsScalableStandard(id->idInsOpt())); // xx
break;
}
assert(isVectorRegister(id->idReg1())); // ddddd
assert(isLowPredicateRegister(id->idReg2())); // ggg
assert(isVectorRegister(id->idReg3())); // mmmmm
assert(isValidVectorElemsizeWidening(id->idOpSize()));
assert(isScalableVectorSize(id->idOpSize()));
break;

// Scalable, possibly FP.
Expand Down
1 change: 1 addition & 0 deletions src/coreclr/jit/hwintrinsiclistarm64sve.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
// Sve
HARDWARE_INTRINSIC(Sve, Abs, -1, -1, false, {INS_sve_abs, INS_invalid, INS_sve_abs, INS_invalid, INS_sve_abs, INS_invalid, INS_sve_abs, INS_invalid, INS_sve_fabs, INS_sve_fabs}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation)
HARDWARE_INTRINSIC(Sve, Add, -1, -1, false, {INS_sve_add, INS_sve_add, INS_sve_add, INS_sve_add, INS_sve_add, INS_sve_add, INS_sve_add, INS_sve_add, INS_sve_fadd, INS_sve_fadd}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_OptionalEmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation)
HARDWARE_INTRINSIC(Sve, AddAcross, -1, 1, true, {INS_sve_saddv, INS_sve_uaddv, INS_sve_saddv, INS_sve_uaddv, INS_sve_saddv, INS_sve_uaddv, INS_sve_uaddv, INS_sve_uaddv, INS_sve_faddv, INS_sve_faddv}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_EmbeddedMaskedOperation)
HARDWARE_INTRINSIC(Sve, ConditionalSelect, -1, 3, true, {INS_sve_sel, INS_sve_sel, INS_sve_sel, INS_sve_sel, INS_sve_sel, INS_sve_sel, INS_sve_sel, INS_sve_sel, INS_sve_sel, INS_sve_sel}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_ExplicitMaskedOperation|HW_Flag_SupportsContainment)
HARDWARE_INTRINSIC(Sve, Count16BitElements, 0, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_cnth, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Scalar, HW_Flag_Scalable|HW_Flag_HasEnumOperand|HW_Flag_SpecialCodeGen|HW_Flag_NoFloatingPointUsed)
HARDWARE_INTRINSIC(Sve, Count32BitElements, 0, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_cntw, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Scalar, HW_Flag_Scalable|HW_Flag_HasEnumOperand|HW_Flag_SpecialCodeGen|HW_Flag_NoFloatingPointUsed)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,68 @@ internal Arm64() { }
/// </summary>
public static unsafe Vector<double> Add(Vector<double> left, Vector<double> right) { throw new PlatformNotSupportedException(); }

/// AddAcross : Add reduction

/// <summary>
/// float64_t svaddv[_f64](svbool_t pg, svfloat64_t op)
/// FADDV Dresult, Pg, Zop.D
/// </summary>
public static unsafe Vector<double> AddAcross(Vector<double> value) { throw new PlatformNotSupportedException(); }

/// <summary>
/// int64_t svaddv[_s16](svbool_t pg, svint16_t op)
/// SADDV Dresult, Pg, Zop.H
/// </summary>
public static unsafe Vector<long> AddAcross(Vector<short> value) { throw new PlatformNotSupportedException(); }

/// <summary>
/// int64_t svaddv[_s32](svbool_t pg, svint32_t op)
/// SADDV Dresult, Pg, Zop.S
/// </summary>
public static unsafe Vector<long> AddAcross(Vector<int> value) { throw new PlatformNotSupportedException(); }

/// <summary>
/// int64_t svaddv[_s8](svbool_t pg, svint8_t op)
/// SADDV Dresult, Pg, Zop.B
/// </summary>
public static unsafe Vector<long> AddAcross(Vector<sbyte> value) { throw new PlatformNotSupportedException(); }

/// <summary>
/// int64_t svaddv[_s64](svbool_t pg, svint64_t op)
/// UADDV Dresult, Pg, Zop.D
/// </summary>
public static unsafe Vector<long> AddAcross(Vector<long> value) { throw new PlatformNotSupportedException(); }

/// <summary>
/// float32_t svaddv[_f32](svbool_t pg, svfloat32_t op)
/// FADDV Sresult, Pg, Zop.S
/// </summary>
public static unsafe Vector<float> AddAcross(Vector<float> value) { throw new PlatformNotSupportedException(); }

/// <summary>
/// uint64_t svaddv[_u8](svbool_t pg, svuint8_t op)
/// UADDV Dresult, Pg, Zop.B
/// </summary>
public static unsafe Vector<ulong> AddAcross(Vector<byte> value) { throw new PlatformNotSupportedException(); }

/// <summary>
/// uint64_t svaddv[_u16](svbool_t pg, svuint16_t op)
/// UADDV Dresult, Pg, Zop.H
/// </summary>
public static unsafe Vector<ulong> AddAcross(Vector<ushort> value) { throw new PlatformNotSupportedException(); }

/// <summary>
/// uint64_t svaddv[_u32](svbool_t pg, svuint32_t op)
/// UADDV Dresult, Pg, Zop.S
/// </summary>
public static unsafe Vector<ulong> AddAcross(Vector<uint> value) { throw new PlatformNotSupportedException(); }

/// <summary>
/// uint64_t svaddv[_u64](svbool_t pg, svuint64_t op)
/// UADDV Dresult, Pg, Zop.D
/// </summary>
public static unsafe Vector<ulong> AddAcross(Vector<ulong> value) { throw new PlatformNotSupportedException(); }


/// ConditionalSelect : Conditionally select elements

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,69 @@ internal Arm64() { }
public static unsafe Vector<double> Add(Vector<double> left, Vector<double> right) => Add(left, right);


/// AddAcross : Add reduction

/// <summary>
/// float64_t svaddv[_f64](svbool_t pg, svfloat64_t op)
/// FADDV Dresult, Pg, Zop.D
/// </summary>
public static unsafe Vector<double> AddAcross(Vector<double> value) => AddAcross(value);

/// <summary>
/// int64_t svaddv[_s16](svbool_t pg, svint16_t op)
/// SADDV Dresult, Pg, Zop.H
/// </summary>
public static unsafe Vector<long> AddAcross(Vector<short> value) => AddAcross(value);

/// <summary>
/// int64_t svaddv[_s32](svbool_t pg, svint32_t op)
/// SADDV Dresult, Pg, Zop.S
/// </summary>
public static unsafe Vector<long> AddAcross(Vector<int> value) => AddAcross(value);

/// <summary>
/// int64_t svaddv[_s8](svbool_t pg, svint8_t op)
/// SADDV Dresult, Pg, Zop.B
/// </summary>
public static unsafe Vector<long> AddAcross(Vector<sbyte> value) => AddAcross(value);

/// <summary>
/// int64_t svaddv[_s64](svbool_t pg, svint64_t op)
/// UADDV Dresult, Pg, Zop.D
/// </summary>
public static unsafe Vector<long> AddAcross(Vector<long> value) => AddAcross(value);

/// <summary>
/// float32_t svaddv[_f32](svbool_t pg, svfloat32_t op)
/// FADDV Sresult, Pg, Zop.S
/// </summary>
public static unsafe Vector<float> AddAcross(Vector<float> value) => AddAcross(value);

/// <summary>
/// uint64_t svaddv[_u8](svbool_t pg, svuint8_t op)
/// UADDV Dresult, Pg, Zop.B
/// </summary>
public static unsafe Vector<ulong> AddAcross(Vector<byte> value) => AddAcross(value);

/// <summary>
/// uint64_t svaddv[_u16](svbool_t pg, svuint16_t op)
/// UADDV Dresult, Pg, Zop.H
/// </summary>
public static unsafe Vector<ulong> AddAcross(Vector<ushort> value) => AddAcross(value);

/// <summary>
/// uint64_t svaddv[_u32](svbool_t pg, svuint32_t op)
/// UADDV Dresult, Pg, Zop.S
/// </summary>
public static unsafe Vector<ulong> AddAcross(Vector<uint> value) => AddAcross(value);

/// <summary>
/// uint64_t svaddv[_u64](svbool_t pg, svuint64_t op)
/// UADDV Dresult, Pg, Zop.D
/// </summary>
public static unsafe Vector<ulong> AddAcross(Vector<ulong> value) => AddAcross(value);


/// ConditionalSelect : Conditionally select elements

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4187,6 +4187,16 @@ internal Arm64() { }
public static System.Numerics.Vector<ulong> Add(System.Numerics.Vector<ulong> left, System.Numerics.Vector<ulong> right) { throw null; }
public static System.Numerics.Vector<float> Add(System.Numerics.Vector<float> left, System.Numerics.Vector<float> right) { throw null; }
public static System.Numerics.Vector<double> Add(System.Numerics.Vector<double> left, System.Numerics.Vector<double> right) { throw null; }
public static System.Numerics.Vector<double> AddAcross(System.Numerics.Vector<double> value) { throw null; }
public static System.Numerics.Vector<long> AddAcross(System.Numerics.Vector<short> value) { throw null; }
public static System.Numerics.Vector<long> AddAcross(System.Numerics.Vector<int> value) { throw null; }
public static System.Numerics.Vector<long> AddAcross(System.Numerics.Vector<sbyte> value) { throw null; }
public static System.Numerics.Vector<long> AddAcross(System.Numerics.Vector<long> value) { throw null; }
public static System.Numerics.Vector<float> AddAcross(System.Numerics.Vector<float> value) { throw null; }
public static System.Numerics.Vector<ulong> AddAcross(System.Numerics.Vector<byte> value) { throw null; }
public static System.Numerics.Vector<ulong> AddAcross(System.Numerics.Vector<ushort> value) { throw null; }
public static System.Numerics.Vector<ulong> AddAcross(System.Numerics.Vector<uint> value) { throw null; }
public static System.Numerics.Vector<ulong> AddAcross(System.Numerics.Vector<ulong> value) { throw null; }
public static ulong Count16BitElements([ConstantExpected] SveMaskPattern pattern = SveMaskPattern.All) { throw null; }
public static ulong Count32BitElements([ConstantExpected] SveMaskPattern pattern = SveMaskPattern.All) { throw null; }
public static ulong Count64BitElements([ConstantExpected] SveMaskPattern pattern = SveMaskPattern.All) { throw null; }
Expand Down
Loading
Loading