Skip to content

Commit e012fd4

Browse files
Minor cleanup of the Vector64/128/256/512 implementations to improve fallbacks (#103095)
* Minor cleanup of the Vector64/128/256/512 implementations to improve fallbacks
* Ensure gtNewSimdSumNode maintains consistency with the software fallback
* Ensure Vector128.Sum also does pairwise adds for floating-point
* Use the right type in the gtNewSimdBinOpNode call
* Don't regress fallback scenarios using AndNot
1 parent 2dba5a3 commit e012fd4

File tree

10 files changed

+355
-314
lines changed

10 files changed

+355
-314
lines changed

src/coreclr/jit/gentree.cpp

+38-8
Original file line numberDiff line numberDiff line change
@@ -25510,20 +25510,48 @@ GenTree* Compiler::gtNewSimdSumNode(var_types type, GenTree* op1, CorInfoType si
2551025510
{
2551125511
assert(IsBaselineVector512IsaSupportedDebugOnly());
2551225512
GenTree* op1Dup = fgMakeMultiUse(&op1);
25513-
op1 = gtNewSimdGetUpperNode(TYP_SIMD32, op1, simdBaseJitType, simdSize);
25514-
op1Dup = gtNewSimdGetLowerNode(TYP_SIMD32, op1Dup, simdBaseJitType, simdSize);
25515-
simdSize = simdSize / 2;
25516-
op1 = gtNewSimdBinOpNode(GT_ADD, TYP_SIMD32, op1, op1Dup, simdBaseJitType, simdSize);
25513+
25514+
op1 = gtNewSimdGetLowerNode(TYP_SIMD32, op1, simdBaseJitType, simdSize);
25515+
op1Dup = gtNewSimdGetUpperNode(TYP_SIMD32, op1Dup, simdBaseJitType, simdSize);
25516+
25517+
if (varTypeIsFloating(simdBaseType))
25518+
{
25519+
// We need to ensure deterministic results which requires
25520+
// consistently adding values together. Since many operations
25521+
// end up operating on 128-bit lanes, we break the sum up the same way.
25522+
25523+
op1 = gtNewSimdSumNode(type, op1, simdBaseJitType, 32);
25524+
op1Dup = gtNewSimdSumNode(type, op1Dup, simdBaseJitType, 32);
25525+
25526+
return gtNewOperNode(GT_ADD, type, op1, op1Dup);
25527+
}
25528+
25529+
simdSize = 32;
25530+
op1 = gtNewSimdBinOpNode(GT_ADD, TYP_SIMD32, op1, op1Dup, simdBaseJitType, 32);
2551725531
}
2551825532

2551925533
if (simdSize == 32)
2552025534
{
2552125535
assert(compIsaSupportedDebugOnly(InstructionSet_AVX2));
2552225536
GenTree* op1Dup = fgMakeMultiUse(&op1);
25523-
op1 = gtNewSimdGetUpperNode(TYP_SIMD16, op1, simdBaseJitType, simdSize);
25524-
op1Dup = gtNewSimdGetLowerNode(TYP_SIMD16, op1Dup, simdBaseJitType, simdSize);
25525-
simdSize = simdSize / 2;
25526-
op1 = gtNewSimdBinOpNode(GT_ADD, TYP_SIMD16, op1, op1Dup, simdBaseJitType, simdSize);
25537+
25538+
op1 = gtNewSimdGetLowerNode(TYP_SIMD16, op1, simdBaseJitType, simdSize);
25539+
op1Dup = gtNewSimdGetUpperNode(TYP_SIMD16, op1Dup, simdBaseJitType, simdSize);
25540+
25541+
if (varTypeIsFloating(simdBaseType))
25542+
{
25543+
// We need to ensure deterministic results which requires
25544+
// consistently adding values together. Since many operations
25545+
// end up operating on 128-bit lanes, we break the sum up the same way.
25546+
25547+
op1 = gtNewSimdSumNode(type, op1, simdBaseJitType, 16);
25548+
op1Dup = gtNewSimdSumNode(type, op1Dup, simdBaseJitType, 16);
25549+
25550+
return gtNewOperNode(GT_ADD, type, op1, op1Dup);
25551+
}
25552+
25553+
simdSize = 16;
25554+
op1 = gtNewSimdBinOpNode(GT_ADD, TYP_SIMD16, op1, op1Dup, simdBaseJitType, 16);
2552725555
}
2552825556

2552925557
assert(simdSize == 16);
@@ -25534,6 +25562,7 @@ GenTree* Compiler::gtNewSimdSumNode(var_types type, GenTree* op1, CorInfoType si
2553425562
{
2553525563
assert(compIsaSupportedDebugOnly(InstructionSet_SSE2));
2553625564
GenTree* op1Shuffled = fgMakeMultiUse(&op1);
25565+
2553725566
if (compOpportunisticallyDependsOn(InstructionSet_AVX))
2553825567
{
2553925568
assert(compIsaSupportedDebugOnly(InstructionSet_AVX));
@@ -25571,6 +25600,7 @@ GenTree* Compiler::gtNewSimdSumNode(var_types type, GenTree* op1, CorInfoType si
2557125600
{
2557225601
assert(compIsaSupportedDebugOnly(InstructionSet_SSE2));
2557325602
GenTree* op1Shuffled = fgMakeMultiUse(&op1);
25603+
2557425604
if (compOpportunisticallyDependsOn(InstructionSet_AVX))
2557525605
{
2557625606
assert(compIsaSupportedDebugOnly(InstructionSet_AVX));

src/libraries/System.Private.CoreLib/src/System/Numerics/Vector.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ public static Vector<float> Ceiling(Vector<float> value)
253253
/// <returns>A vector whose bits come from <paramref name="left" /> or <paramref name="right" /> based on the value of <paramref name="condition" />.</returns>
254254
[Intrinsic]
255255
[MethodImpl(MethodImplOptions.AggressiveInlining)]
256-
public static Vector<T> ConditionalSelect<T>(Vector<T> condition, Vector<T> left, Vector<T> right) => (left & condition) | (right & ~condition);
256+
public static Vector<T> ConditionalSelect<T>(Vector<T> condition, Vector<T> left, Vector<T> right) => (left & condition) | AndNot(right, condition);
257257

258258
/// <summary>Conditionally selects a value from two vectors on a bitwise basis.</summary>
259259
/// <param name="condition">The mask that is used to select a value from <paramref name="left" /> or <paramref name="right" />.</param>
@@ -1186,7 +1186,7 @@ public static Vector<T> Min<T>(Vector<T> left, Vector<T> right)
11861186
/// <typeparam name="T">The type of the elements in the vector.</typeparam>
11871187
/// <returns>The product of <paramref name="left" /> and <paramref name="right" />.</returns>
11881188
[Intrinsic]
1189-
public static Vector<T> Multiply<T>(T left, Vector<T> right) => left * right;
1189+
public static Vector<T> Multiply<T>(T left, Vector<T> right) => right * left;
11901190

11911191
/// <inheritdoc cref="Vector128.MultiplyAddEstimate(Vector128{double}, Vector128{double}, Vector128{double})" />
11921192
[Intrinsic]

src/libraries/System.Private.CoreLib/src/System/Runtime/Intrinsics/Vector128.cs

+65-44
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,21 @@ public static bool IsHardwareAccelerated
5858
[MethodImpl(MethodImplOptions.AggressiveInlining)]
5959
public static Vector128<T> Abs<T>(Vector128<T> vector)
6060
{
61-
return Create(
62-
Vector64.Abs(vector._lower),
63-
Vector64.Abs(vector._upper)
64-
);
61+
if ((typeof(T) == typeof(byte))
62+
|| (typeof(T) == typeof(ushort))
63+
|| (typeof(T) == typeof(uint))
64+
|| (typeof(T) == typeof(ulong))
65+
|| (typeof(T) == typeof(nuint)))
66+
{
67+
return vector;
68+
}
69+
else
70+
{
71+
return Create(
72+
Vector64.Abs(vector._lower),
73+
Vector64.Abs(vector._upper)
74+
);
75+
}
6576
}
6677

6778
/// <summary>Adds two vectors to compute their sum.</summary>
@@ -80,13 +91,7 @@ public static Vector128<T> Abs<T>(Vector128<T> vector)
8091
/// <returns>The bitwise-and of <paramref name="left" /> and the ones-complement of <paramref name="right" />.</returns>
8192
[Intrinsic]
8293
[MethodImpl(MethodImplOptions.AggressiveInlining)]
83-
public static Vector128<T> AndNot<T>(Vector128<T> left, Vector128<T> right)
84-
{
85-
return Create(
86-
Vector64.AndNot(left._lower, right._lower),
87-
Vector64.AndNot(left._upper, right._upper)
88-
);
89-
}
94+
public static Vector128<T> AndNot<T>(Vector128<T> left, Vector128<T> right) => left & ~right;
9095

9196
/// <summary>Reinterprets a <see cref="Vector128{TFrom}" /> as a new <see cref="Vector128{TTo}" />.</summary>
9297
/// <typeparam name="TFrom">The type of the elements in the input vector.</typeparam>
@@ -377,10 +382,26 @@ public static Vector<T> AsVector<T>(this Vector128<T> value)
377382
[MethodImpl(MethodImplOptions.AggressiveInlining)]
378383
internal static Vector128<T> Ceiling<T>(Vector128<T> vector)
379384
{
380-
return Create(
381-
Vector64.Ceiling(vector._lower),
382-
Vector64.Ceiling(vector._upper)
383-
);
385+
if ((typeof(T) == typeof(byte))
386+
|| (typeof(T) == typeof(short))
387+
|| (typeof(T) == typeof(int))
388+
|| (typeof(T) == typeof(long))
389+
|| (typeof(T) == typeof(nint))
390+
|| (typeof(T) == typeof(nuint))
391+
|| (typeof(T) == typeof(sbyte))
392+
|| (typeof(T) == typeof(ushort))
393+
|| (typeof(T) == typeof(uint))
394+
|| (typeof(T) == typeof(ulong)))
395+
{
396+
return vector;
397+
}
398+
else
399+
{
400+
return Create(
401+
Vector64.Ceiling(vector._lower),
402+
Vector64.Ceiling(vector._upper)
403+
);
404+
}
384405
}
385406

386407
/// <summary>Computes the ceiling of each element in a vector.</summary>
@@ -406,13 +427,7 @@ internal static Vector128<T> Ceiling<T>(Vector128<T> vector)
406427
/// <exception cref="NotSupportedException">The type of <paramref name="condition" />, <paramref name="left" />, and <paramref name="right" /> (<typeparamref name="T" />) is not supported.</exception>
407428
[Intrinsic]
408429
[MethodImpl(MethodImplOptions.AggressiveInlining)]
409-
public static Vector128<T> ConditionalSelect<T>(Vector128<T> condition, Vector128<T> left, Vector128<T> right)
410-
{
411-
return Create(
412-
Vector64.ConditionalSelect(condition._lower, left._lower, right._lower),
413-
Vector64.ConditionalSelect(condition._upper, left._upper, right._upper)
414-
);
415-
}
430+
public static Vector128<T> ConditionalSelect<T>(Vector128<T> condition, Vector128<T> left, Vector128<T> right) => (left & condition) | AndNot(right, condition);
416431

417432
/// <summary>Converts a <see cref="Vector128{Int64}" /> to a <see cref="Vector128{Double}" />.</summary>
418433
/// <param name="vector">The vector to convert.</param>
@@ -1413,16 +1428,7 @@ public static Vector128<T> CreateScalarUnsafe<T>(T value)
14131428
/// <exception cref="NotSupportedException">The type of <paramref name="left" /> and <paramref name="right" /> (<typeparamref name="T" />) is not supported.</exception>
14141429
[Intrinsic]
14151430
[MethodImpl(MethodImplOptions.AggressiveInlining)]
1416-
public static T Dot<T>(Vector128<T> left, Vector128<T> right)
1417-
{
1418-
// Doing this as Dot(lower) + Dot(upper) is important for floating-point determinism
1419-
// This is because the underlying dpps instruction on x86/x64 will do this equivalently
1420-
// and otherwise the software vs accelerated implementations may differ in returned result.
1421-
1422-
T result = Vector64.Dot(left._lower, right._lower);
1423-
result = Scalar<T>.Add(result, Vector64.Dot(left._upper, right._upper));
1424-
return result;
1425-
}
1431+
public static T Dot<T>(Vector128<T> left, Vector128<T> right) => Sum(left * right);
14261432

14271433
/// <summary>Compares two vectors to determine if they are equal on a per-element basis.</summary>
14281434
/// <typeparam name="T">The type of the elements in the vector.</typeparam>
@@ -1519,10 +1525,26 @@ public static uint ExtractMostSignificantBits<T>(this Vector128<T> vector)
15191525
[MethodImpl(MethodImplOptions.AggressiveInlining)]
15201526
internal static Vector128<T> Floor<T>(Vector128<T> vector)
15211527
{
1522-
return Create(
1523-
Vector64.Floor(vector._lower),
1524-
Vector64.Floor(vector._upper)
1525-
);
1528+
if ((typeof(T) == typeof(byte))
1529+
|| (typeof(T) == typeof(short))
1530+
|| (typeof(T) == typeof(int))
1531+
|| (typeof(T) == typeof(long))
1532+
|| (typeof(T) == typeof(nint))
1533+
|| (typeof(T) == typeof(nuint))
1534+
|| (typeof(T) == typeof(sbyte))
1535+
|| (typeof(T) == typeof(ushort))
1536+
|| (typeof(T) == typeof(uint))
1537+
|| (typeof(T) == typeof(ulong)))
1538+
{
1539+
return vector;
1540+
}
1541+
else
1542+
{
1543+
return Create(
1544+
Vector64.Floor(vector._lower),
1545+
Vector64.Floor(vector._upper)
1546+
);
1547+
}
15261548
}
15271549

15281550
/// <summary>Computes the floor of each element in a vector.</summary>
@@ -1989,7 +2011,7 @@ public static Vector128<T> Min<T>(Vector128<T> left, Vector128<T> right)
19892011
/// <returns>The product of <paramref name="left" /> and <paramref name="right" />.</returns>
19902012
/// <exception cref="NotSupportedException">The type of <paramref name="left" /> and <paramref name="right"/> (<typeparamref name="T" />) is not supported.</exception>
19912013
[Intrinsic]
1992-
public static Vector128<T> Multiply<T>(T left, Vector128<T> right) => left * right;
2014+
public static Vector128<T> Multiply<T>(T left, Vector128<T> right) => right * left;
19932015

19942016
/// <inheritdoc cref="Vector64.MultiplyAddEstimate(Vector64{double}, Vector64{double}, Vector64{double})" />
19952017
[Intrinsic]
@@ -2735,14 +2757,13 @@ public static void StoreUnsafe<T>(this Vector128<T> source, ref T destination, n
27352757
[MethodImpl(MethodImplOptions.AggressiveInlining)]
27362758
public static T Sum<T>(Vector128<T> vector)
27372759
{
2738-
T sum = default!;
2739-
2740-
for (int index = 0; index < Vector128<T>.Count; index++)
2741-
{
2742-
sum = Scalar<T>.Add(sum, vector.GetElementUnsafe(index));
2743-
}
2760+
// Doing this as Sum(lower) + Sum(upper) is important for floating-point determinism
2761+
// This is because the underlying dpps instruction on x86/x64 will do this equivalently
2762+
// and otherwise the software vs accelerated implementations may differ in returned result.
27442763

2745-
return sum;
2764+
T result = Vector64.Sum(vector._lower);
2765+
result = Scalar<T>.Add(result, Vector64.Sum(vector._upper));
2766+
return result;
27462767
}
27472768

27482769
/// <summary>Converts the given vector to a scalar containing the value of the first element.</summary>

0 commit comments

Comments
 (0)