Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vector.Sum(Vector<T>) API implementation for horizontal add. #53527

Merged
merged 14 commits into from
Jun 11, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions src/coreclr/jit/simdashwintrinsic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,64 @@ GenTree* Compiler::impSimdAsHWIntrinsicSpecial(NamedIntrinsic intrinsic,
}
break;
}
case NI_VectorT128_Sum:
{
if (compOpportunisticallyDependsOn(InstructionSet_SSSE3))
{
GenTree* tmp;
unsigned vectorLength = getSIMDVectorLength(simdSize, simdBaseType);
int haddCount = genLog2(vectorLength);

for (int i = 0; i < haddCount; i++)
{
op1 = impCloneExpr(op1, &tmp, clsHnd, (unsigned)CHECK_SPILL_ALL,
nullptr DEBUGARG("Clone op1 for Vector<T>.Sum"));
op1 = gtNewSimdAsHWIntrinsicNode(simdType, op1, tmp, NI_SSSE3_HorizontalAdd,
simdBaseJitType, simdSize);
}

return gtNewSimdAsHWIntrinsicNode(retType, op1, NI_Vector128_ToScalar, simdBaseJitType,
simdSize);
}

return nullptr;
}
case NI_VectorT256_Sum:
{
// HorizontalAdd combines pairs so we need log2(vectorLength) passes to sum all elements together.
unsigned vectorLength = getSIMDVectorLength(simdSize, simdBaseType);
int haddCount = genLog2(vectorLength) - 1; // Minus 1 because for the last pass we split the vector
// to low / high and add them together.
GenTree* tmp;
NamedIntrinsic horizontalAdd = NI_AVX2_HorizontalAdd;
NamedIntrinsic add = NI_SSE2_Add;

if (simdBaseType == TYP_DOUBLE)
{
horizontalAdd = NI_AVX_HorizontalAdd;
}
else if (simdBaseType == TYP_FLOAT)
{
horizontalAdd = NI_AVX_HorizontalAdd;
add = NI_SSE_Add;
}

for (int i = 0; i < haddCount; i++)
{
op1 = impCloneExpr(op1, &tmp, clsHnd, (unsigned)CHECK_SPILL_ALL,
nullptr DEBUGARG("Clone op1 for Vector<T>.Sum"));
op1 = gtNewSimdAsHWIntrinsicNode(simdType, op1, tmp, horizontalAdd, simdBaseJitType, simdSize);
}

op1 = impCloneExpr(op1, &tmp, clsHnd, (unsigned)CHECK_SPILL_ALL,
nullptr DEBUGARG("Clone op1 for Vector<T>.Sum"));
op1 = gtNewSimdAsHWIntrinsicNode(TYP_SIMD16, op1, gtNewIconNode(0x01, TYP_INT),
NI_AVX_ExtractVector128, simdBaseJitType, simdSize);

op1 = gtNewSimdAsHWIntrinsicNode(TYP_SIMD16, op1, tmp, add, simdBaseJitType, 16);

return gtNewSimdAsHWIntrinsicNode(retType, op1, NI_Vector128_ToScalar, simdBaseJitType, 16);
}
#elif defined(TARGET_ARM64)
case NI_VectorT128_Abs:
{
Expand Down
1 change: 1 addition & 0 deletions src/coreclr/jit/simdashwintrinsiclistarm64.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ SIMD_AS_HWINTRINSIC_ID(VectorT128, op_Inequality,
SIMD_AS_HWINTRINSIC_ID(VectorT128, op_Multiply, 2, {NI_VectorT128_op_Multiply, NI_VectorT128_op_Multiply, NI_VectorT128_op_Multiply, NI_VectorT128_op_Multiply, NI_VectorT128_op_Multiply, NI_VectorT128_op_Multiply, NI_Illegal, NI_Illegal, NI_VectorT128_op_Multiply, NI_VectorT128_op_Multiply}, SimdAsHWIntrinsicFlag::None)
SIMD_AS_HWINTRINSIC_ID(VectorT128, op_Subtraction, 2, {NI_AdvSimd_Subtract, NI_AdvSimd_Subtract, NI_AdvSimd_Subtract, NI_AdvSimd_Subtract, NI_AdvSimd_Subtract, NI_AdvSimd_Subtract, NI_AdvSimd_Subtract, NI_AdvSimd_Subtract, NI_AdvSimd_Subtract, NI_AdvSimd_Arm64_Subtract}, SimdAsHWIntrinsicFlag::None)
SIMD_AS_HWINTRINSIC_ID(VectorT128, SquareRoot, 1, {NI_Illegal, NI_Illegal, NI_Illegal, NI_Illegal, NI_Illegal, NI_Illegal, NI_Illegal, NI_Illegal, NI_AdvSimd_Arm64_Sqrt, NI_AdvSimd_Arm64_Sqrt}, SimdAsHWIntrinsicFlag::None)
SIMD_AS_HWINTRINSIC_ID(VectorT128, Sum, 1, {NI_AdvSimd_Arm64_AddAcross, NI_AdvSimd_Arm64_AddAcross, NI_AdvSimd_Arm64_AddAcross, NI_AdvSimd_Arm64_AddAcross, NI_AdvSimd_Arm64_AddAcross, NI_AdvSimd_Arm64_AddAcross, NI_Illegal, NI_Illegal, NI_Illegal, NI_Illegal}, SimdAsHWIntrinsicFlag::None)

#undef SIMD_AS_HWINTRINSIC_NM
#undef SIMD_AS_HWINTRINSIC_ID
Expand Down
2 changes: 2 additions & 0 deletions src/coreclr/jit/simdashwintrinsiclistxarch.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ SIMD_AS_HWINTRINSIC_ID(VectorT128, op_Inequality,
SIMD_AS_HWINTRINSIC_ID(VectorT128, op_Multiply, 2, {NI_Illegal, NI_Illegal, NI_VectorT128_op_Multiply, NI_VectorT128_op_Multiply, NI_VectorT128_op_Multiply, NI_VectorT128_op_Multiply, NI_Illegal, NI_Illegal, NI_VectorT128_op_Multiply, NI_VectorT128_op_Multiply}, SimdAsHWIntrinsicFlag::None)
SIMD_AS_HWINTRINSIC_ID(VectorT128, op_Subtraction, 2, {NI_SSE2_Subtract, NI_SSE2_Subtract, NI_SSE2_Subtract, NI_SSE2_Subtract, NI_SSE2_Subtract, NI_SSE2_Subtract, NI_SSE2_Subtract, NI_SSE2_Subtract, NI_SSE_Subtract, NI_SSE2_Subtract}, SimdAsHWIntrinsicFlag::None)
SIMD_AS_HWINTRINSIC_ID(VectorT128, SquareRoot, 1, {NI_Illegal, NI_Illegal, NI_Illegal, NI_Illegal, NI_Illegal, NI_Illegal, NI_Illegal, NI_Illegal, NI_SSE_Sqrt, NI_SSE2_Sqrt}, SimdAsHWIntrinsicFlag::None)
SIMD_AS_HWINTRINSIC_ID(VectorT128, Sum, 1, {NI_Illegal, NI_Illegal, NI_VectorT128_Sum, NI_VectorT128_Sum, NI_VectorT128_Sum, NI_VectorT128_Sum, NI_Illegal, NI_Illegal, NI_VectorT128_Sum, NI_VectorT128_Sum}, SimdAsHWIntrinsicFlag::None)

// *************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************
// ISA ID Name NumArg Instructions Flags
Expand Down Expand Up @@ -170,6 +171,7 @@ SIMD_AS_HWINTRINSIC_ID(VectorT256, op_Inequality,
SIMD_AS_HWINTRINSIC_ID(VectorT256, op_Multiply, 2, {NI_Illegal, NI_Illegal, NI_VectorT256_op_Multiply, NI_VectorT256_op_Multiply, NI_VectorT256_op_Multiply, NI_VectorT256_op_Multiply, NI_Illegal, NI_Illegal, NI_VectorT256_op_Multiply, NI_VectorT256_op_Multiply}, SimdAsHWIntrinsicFlag::None)
SIMD_AS_HWINTRINSIC_ID(VectorT256, op_Subtraction, 2, {NI_AVX2_Subtract, NI_AVX2_Subtract, NI_AVX2_Subtract, NI_AVX2_Subtract, NI_AVX2_Subtract, NI_AVX2_Subtract, NI_AVX2_Subtract, NI_AVX2_Subtract, NI_AVX_Subtract, NI_AVX_Subtract}, SimdAsHWIntrinsicFlag::None)
SIMD_AS_HWINTRINSIC_ID(VectorT256, SquareRoot, 1, {NI_Illegal, NI_Illegal, NI_Illegal, NI_Illegal, NI_Illegal, NI_Illegal, NI_Illegal, NI_Illegal, NI_AVX_Sqrt, NI_AVX_Sqrt}, SimdAsHWIntrinsicFlag::None)
SIMD_AS_HWINTRINSIC_ID(VectorT256, Sum, 1, {NI_Illegal, NI_Illegal, NI_VectorT256_Sum, NI_VectorT256_Sum, NI_VectorT256_Sum, NI_VectorT256_Sum, NI_Illegal, NI_Illegal, NI_VectorT256_Sum, NI_VectorT256_Sum}, SimdAsHWIntrinsicFlag::None)

#undef SIMD_AS_HWINTRINSIC_NM
#undef SIMD_AS_HWINTRINSIC_ID
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ public static partial class Vector
[System.CLSCompliantAttribute(false)]
public static void Widen(System.Numerics.Vector<System.UInt32> source, out System.Numerics.Vector<System.UInt64> low, out System.Numerics.Vector<System.UInt64> high) { throw null; }
public static System.Numerics.Vector<T> Xor<T>(System.Numerics.Vector<T> left, System.Numerics.Vector<T> right) where T : struct { throw null; }
public static T Sum<T>(System.Numerics.Vector<T> value) where T : struct { throw null; }
}
public partial struct Vector2 : System.IEquatable<System.Numerics.Vector2>, System.IFormattable
{
Expand Down
43 changes: 43 additions & 0 deletions src/libraries/System.Numerics.Vectors/tests/GenericVectorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3137,6 +3137,49 @@ private unsafe void TestAs<TFrom, TTo>() where TFrom : unmanaged where TTo : unm
}
#endregion

#region Sum

[Fact]
public void SumInt32() => TestSum<int>(x => x.Aggregate((a, b) => a + b));

[Fact]
public void SumInt64() => TestSum<long>(x => x.Aggregate((a, b) => a + b));

[Fact]
public void SumSingle() => TestSum<float>(x => x.Aggregate((a, b) => a + b));

[Fact]
public void SumDouble() => TestSum<double>(x => x.Aggregate((a, b) => a + b));

[Fact]
public void SumUInt32() => TestSum<uint>(x => x.Aggregate((a, b) => a + b));

[Fact]
public void SumUInt64() => TestSum<ulong>(x => x.Aggregate((a, b) => a + b));

[Fact]
public void SumByte() => TestSum<byte>(x => x.Aggregate((a, b) => (byte)(a + b)));

[Fact]
public void SumSByte() => TestSum<sbyte>(x => x.Aggregate((a, b) => (sbyte)(a + b)));

[Fact]
public void SumInt16() => TestSum<short>(x => x.Aggregate((a, b) => (short)(a + b)));

[Fact]
public void SumUInt16() => TestSum<ushort>(x => x.Aggregate((a, b) => (ushort)(a + b)));

private static void TestSum<T>(Func<T[], T> expected) where T : struct, IEquatable<T>
{
T[] values = GenerateRandomValuesForVector<T>();
Vector<T> vector = new(values);
T sum = Vector.Sum(vector);

AssertEqual(expected(values), sum, "Sum");
}

#endregion

#region Helper Methods
private static void AssertEqual<T>(T expected, T actual, string operation, int precision = -1) where T : IEquatable<T>
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1292,5 +1292,14 @@ public static Vector<TTo> As<TFrom, TTo>(this Vector<TFrom> vector)

return Unsafe.As<Vector<TFrom>, Vector<TTo>>(ref vector);
}

/// <summary>
/// Returns the sum of all elements inside the vector.
/// </summary>
[Intrinsic]
public static T Sum<T>(Vector<T> value) where T : struct
{
return Vector<T>.Sum(value);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -822,6 +822,19 @@ internal static T Dot(Vector<T> left, Vector<T> right)
return product;
}

[Intrinsic]
internal static T Sum(Vector<T> value)
{
T sum = default;

for (nint index = 0; index < Count; index++)
{
sum = ScalarAdd(sum, value.GetElement(index));
}

return sum;
}

[Intrinsic]
internal static unsafe Vector<T> SquareRoot(Vector<T> value)
{
Expand Down