From 3bf40a378f00cb5bf18ff62796bc7097719b974c Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Thu, 28 Sep 2023 08:51:48 -0400 Subject: [PATCH] Update TensorPrimitives aggregations to vectorize handling of remaining elements (#92672) * Update TensorPrimitives.CosineSimilarity to vectorize handling of remaining elements * Vectorize remainder handling for Aggregate helpers --- .../Numerics/Tensors/TensorPrimitives.cs | 44 +- .../Tensors/TensorPrimitives.netcore.cs | 454 ++++++++++++------ .../Tensors/TensorPrimitives.netstandard.cs | 150 ++++-- 3 files changed, 443 insertions(+), 205 deletions(-) diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs index 48ff55b4f0bd2..d5f7382046887 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs @@ -126,7 +126,7 @@ public static float Distance(ReadOnlySpan x, ReadOnlySpan y) ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); } - return MathF.Sqrt(Aggregate(0f, x, y)); + return MathF.Sqrt(Aggregate(x, y)); } /// Computes the element-wise result of: / . @@ -162,7 +162,7 @@ public static float Dot(ReadOnlySpan x, ReadOnlySpan y) // BLAS1: ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); } - return Aggregate(0f, x, y); + return Aggregate(x, y); } /// Computes the element-wise result of: pow(e, ). @@ -545,7 +545,7 @@ public static void Negate(ReadOnlySpan x, Span destination) => /// The first tensor, represented as a span. /// The L2 norm. public static float Norm(ReadOnlySpan x) => // BLAS1: nrm2 - MathF.Sqrt(Aggregate(0f, x)); + MathF.Sqrt(Aggregate(x)); /// Computes the product of all elements in . /// The tensor, represented as a span. 
@@ -558,7 +558,7 @@ public static float Product(ReadOnlySpan x) ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); } - return Aggregate(1.0f, x); + return Aggregate(x); } /// Computes the product of the element-wise result of: - . @@ -580,7 +580,7 @@ public static float ProductOfDifferences(ReadOnlySpan x, ReadOnlySpan(1.0f, x, y); + return Aggregate(x, y); } /// Computes the product of the element-wise result of: + . @@ -602,7 +602,7 @@ public static float ProductOfSums(ReadOnlySpan x, ReadOnlySpan y) ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); } - return Aggregate(1.0f, x, y); + return Aggregate(x, y); } /// @@ -703,7 +703,7 @@ public static void Subtract(ReadOnlySpan x, float y, Span destinat /// The tensor, represented as a span. /// The result of adding all elements in , or zero if is empty. public static float Sum(ReadOnlySpan x) => - Aggregate(0f, x); + Aggregate(x); /// Computes the sum of the absolute values of every element in . /// The tensor, represented as a span. @@ -713,14 +713,14 @@ public static float Sum(ReadOnlySpan x) => /// This method corresponds to the asum method defined by BLAS1. /// public static float SumOfMagnitudes(ReadOnlySpan x) => - Aggregate(0f, x); + Aggregate(x); /// Computes the sum of the squares of every element in . /// The tensor, represented as a span. /// The result of adding every element in multiplied by itself, or zero if is empty. /// This method effectively does .Sum(.Multiply(, )). public static float SumOfSquares(ReadOnlySpan x) => - Aggregate(0f, x); + Aggregate(x); /// Computes the element-wise result of: tanh(). /// The tensor, represented as a span. @@ -739,5 +739,31 @@ public static void Tanh(ReadOnlySpan x, Span destination) destination[i] = MathF.Tanh(x[i]); } } + + /// Mask used to handle remaining elements after vectorized handling of the input. + /// + /// Logically 17 rows of 16 uints (row indices 0 through 16). 
The Nth row should be used to handle N remaining elements at the + /// end of the input, where elements in the vector prior to that will be zero'd. + /// + private static ReadOnlySpan RemainderUInt32Mask_16x16 => new uint[] + { + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + }; } } diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs index e82f67b1f27d5..d210d9f0f8240 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs @@ -58,12 +58,6 @@ 
private static float CosineSimilarityCore(ReadOnlySpan x, ReadOnlySpan= Vector512.Count) { @@ -76,6 +70,7 @@ private static float CosineSimilarityCore(ReadOnlySpan x, ReadOnlySpan.Count; + int i = 0; do { Vector512 xVec = Vector512.LoadUnsafe(ref xRef, (uint)i); @@ -89,13 +84,28 @@ private static float CosineSimilarityCore(ReadOnlySpan x, ReadOnlySpan xVec = Vector512.LoadUnsafe(ref xRef, (uint)(x.Length - Vector512.Count)); + Vector512 yVec = Vector512.LoadUnsafe(ref yRef, (uint)(x.Length - Vector512.Count)); + + Vector512 remainderMask = LoadRemainderMaskSingleVector512(x.Length - i); + xVec &= remainderMask; + yVec &= remainderMask; + + dotProductVector = FusedMultiplyAdd(xVec, yVec, dotProductVector); + xSumOfSquaresVector = FusedMultiplyAdd(xVec, xVec, xSumOfSquaresVector); + ySumOfSquaresVector = FusedMultiplyAdd(yVec, yVec, ySumOfSquaresVector); + } + + // Sum(X * Y) / (|X| * |Y|) + return + Vector512.Sum(dotProductVector) / + (MathF.Sqrt(Vector512.Sum(xSumOfSquaresVector)) * MathF.Sqrt(Vector512.Sum(ySumOfSquaresVector))); } - else #endif + if (Vector256.IsHardwareAccelerated && x.Length >= Vector256.Count) { ref float xRef = ref MemoryMarshal.GetReference(x); @@ -107,6 +117,7 @@ private static float CosineSimilarityCore(ReadOnlySpan x, ReadOnlySpan.Count; + int i = 0; do { Vector256 xVec = Vector256.LoadUnsafe(ref xRef, (uint)i); @@ -120,12 +131,28 @@ private static float CosineSimilarityCore(ReadOnlySpan x, ReadOnlySpan xVec = Vector256.LoadUnsafe(ref xRef, (uint)(x.Length - Vector256.Count)); + Vector256 yVec = Vector256.LoadUnsafe(ref yRef, (uint)(x.Length - Vector256.Count)); + + Vector256 remainderMask = LoadRemainderMaskSingleVector256(x.Length - i); + xVec &= remainderMask; + yVec &= remainderMask; + + dotProductVector = FusedMultiplyAdd(xVec, yVec, dotProductVector); + xSumOfSquaresVector = FusedMultiplyAdd(xVec, xVec, xSumOfSquaresVector); + ySumOfSquaresVector = FusedMultiplyAdd(yVec, yVec, ySumOfSquaresVector); + } + + // Sum(X * Y) / (|X| * 
|Y|) + return + Vector256.Sum(dotProductVector) / + (MathF.Sqrt(Vector256.Sum(xSumOfSquaresVector)) * MathF.Sqrt(Vector256.Sum(ySumOfSquaresVector))); } - else if (Vector128.IsHardwareAccelerated && x.Length >= Vector128.Count) + + if (Vector128.IsHardwareAccelerated && x.Length >= Vector128.Count) { ref float xRef = ref MemoryMarshal.GetReference(x); ref float yRef = ref MemoryMarshal.GetReference(y); @@ -136,6 +163,7 @@ private static float CosineSimilarityCore(ReadOnlySpan x, ReadOnlySpan.Count; + int i = 0; do { Vector128 xVec = Vector128.LoadUnsafe(ref xRef, (uint)i); @@ -149,14 +177,31 @@ private static float CosineSimilarityCore(ReadOnlySpan x, ReadOnlySpan xVec = Vector128.LoadUnsafe(ref xRef, (uint)(x.Length - Vector128.Count)); + Vector128 yVec = Vector128.LoadUnsafe(ref yRef, (uint)(x.Length - Vector128.Count)); + + Vector128 remainderMask = LoadRemainderMaskSingleVector128(x.Length - i); + xVec &= remainderMask; + yVec &= remainderMask; + + dotProductVector = FusedMultiplyAdd(xVec, yVec, dotProductVector); + xSumOfSquaresVector = FusedMultiplyAdd(xVec, xVec, xSumOfSquaresVector); + ySumOfSquaresVector = FusedMultiplyAdd(yVec, yVec, ySumOfSquaresVector); + } + + // Sum(X * Y) / (|X| * |Y|) + return + Vector128.Sum(dotProductVector) / + (MathF.Sqrt(Vector128.Sum(xSumOfSquaresVector)) * MathF.Sqrt(Vector128.Sum(ySumOfSquaresVector))); } - // Process any remaining elements past the last vector. - for (; (uint)i < (uint)x.Length; i++) + // Vectorization isn't supported or there are too few elements to vectorize. + // Use a scalar implementation. 
+ float dotProduct = 0f, xSumOfSquares = 0f, ySumOfSquares = 0f; + for (int i = 0; i < x.Length; i++) { dotProduct = MathF.FusedMultiplyAdd(x[i], y[i], dotProduct); xSumOfSquares = MathF.FusedMultiplyAdd(x[i], x[i], xSumOfSquares); @@ -164,187 +209,256 @@ private static float CosineSimilarityCore(ReadOnlySpan x, ReadOnlySpan( - float identityValue, ReadOnlySpan x) + ReadOnlySpan x) where TLoad : struct, IUnaryOperator - where TAggregate : struct, IBinaryOperator + where TAggregate : struct, IAggregationOperator { - // Initialize the result to the identity value - float result = identityValue; - int i = 0; + if (x.Length == 0) + { + return 0; + } + + ref float xRef = ref MemoryMarshal.GetReference(x); #if NET8_0_OR_GREATER - if (Vector512.IsHardwareAccelerated && x.Length >= Vector512.Count * 2) + if (Vector512.IsHardwareAccelerated && x.Length >= Vector512.Count) { - ref float xRef = ref MemoryMarshal.GetReference(x); - // Load the first vector as the initial set of results - Vector512 resultVector = TLoad.Invoke(Vector512.LoadUnsafe(ref xRef, 0)); + Vector512 result = TLoad.Invoke(Vector512.LoadUnsafe(ref xRef, 0)); int oneVectorFromEnd = x.Length - Vector512.Count; + int i = Vector512.Count; // Aggregate additional vectors into the result as long as there's at // least one full vector left to process. - i = Vector512.Count; - do + while (i <= oneVectorFromEnd) { - resultVector = TAggregate.Invoke(resultVector, TLoad.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i))); + result = TAggregate.Invoke(result, TLoad.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i))); i += Vector512.Count; } - while (i <= oneVectorFromEnd); + + // Process the last vector in the span, masking off elements already processed. 
+ if (i != x.Length) + { + result = TAggregate.Invoke(result, + Vector512.ConditionalSelect( + Vector512.Equals(LoadRemainderMaskSingleVector512(x.Length - i), Vector512.Zero), + Vector512.Create(TAggregate.IdentityValue), + TLoad.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)(x.Length - Vector512.Count))))); + } // Aggregate the lanes in the vector back into the scalar result - result = TAggregate.Invoke(result, TAggregate.Invoke(resultVector)); + return TAggregate.Invoke(result); } - else #endif - if (Vector256.IsHardwareAccelerated && x.Length >= Vector256.Count * 2) - { - ref float xRef = ref MemoryMarshal.GetReference(x); + if (Vector256.IsHardwareAccelerated && x.Length >= Vector256.Count) + { // Load the first vector as the initial set of results - Vector256 resultVector = TLoad.Invoke(Vector256.LoadUnsafe(ref xRef, 0)); + Vector256 result = TLoad.Invoke(Vector256.LoadUnsafe(ref xRef, 0)); int oneVectorFromEnd = x.Length - Vector256.Count; + int i = Vector256.Count; // Aggregate additional vectors into the result as long as there's at // least one full vector left to process. - i = Vector256.Count; - do + while (i <= oneVectorFromEnd) { - resultVector = TAggregate.Invoke(resultVector, TLoad.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i))); + result = TAggregate.Invoke(result, TLoad.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i))); i += Vector256.Count; } - while (i <= oneVectorFromEnd); + + // Process the last vector in the span, masking off elements already processed. 
+ if (i != x.Length) + { + result = TAggregate.Invoke(result, + Vector256.ConditionalSelect( + Vector256.Equals(LoadRemainderMaskSingleVector256(x.Length - i), Vector256.Zero), + Vector256.Create(TAggregate.IdentityValue), + TLoad.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)(x.Length - Vector256.Count))))); + } // Aggregate the lanes in the vector back into the scalar result - result = TAggregate.Invoke(result, TAggregate.Invoke(resultVector)); + return TAggregate.Invoke(result); } - else if (Vector128.IsHardwareAccelerated && x.Length >= Vector128.Count * 2) - { - ref float xRef = ref MemoryMarshal.GetReference(x); + if (Vector128.IsHardwareAccelerated && x.Length >= Vector128.Count) + { // Load the first vector as the initial set of results - Vector128 resultVector = TLoad.Invoke(Vector128.LoadUnsafe(ref xRef, 0)); + Vector128 result = TLoad.Invoke(Vector128.LoadUnsafe(ref xRef, 0)); int oneVectorFromEnd = x.Length - Vector128.Count; + int i = Vector128.Count; // Aggregate additional vectors into the result as long as there's at // least one full vector left to process. - i = Vector128.Count; - do + while (i <= oneVectorFromEnd) { - resultVector = TAggregate.Invoke(resultVector, TLoad.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i))); + result = TAggregate.Invoke(result, TLoad.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i))); i += Vector128.Count; } - while (i <= oneVectorFromEnd); + + // Process the last vector in the span, masking off elements already processed. 
+ if (i != x.Length) + { + result = TAggregate.Invoke(result, + Vector128.ConditionalSelect( + Vector128.Equals(LoadRemainderMaskSingleVector128(x.Length - i), Vector128.Zero), + Vector128.Create(TAggregate.IdentityValue), + TLoad.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)(x.Length - Vector128.Count))))); + } // Aggregate the lanes in the vector back into the scalar result - result = TAggregate.Invoke(result, TAggregate.Invoke(resultVector)); + return TAggregate.Invoke(result); } - // Aggregate the remaining items in the input span. - for (; (uint)i < (uint)x.Length; i++) + // Vectorization isn't supported or there are too few elements to vectorize. + // Use a scalar implementation. { - result = TAggregate.Invoke(result, TLoad.Invoke(x[i])); - } + float result = TLoad.Invoke(x[0]); + for (int i = 1; i < x.Length; i++) + { + result = TAggregate.Invoke(result, TLoad.Invoke(x[i])); + } - return result; + return result; + } } private static float Aggregate( - float identityValue, ReadOnlySpan x, ReadOnlySpan y) + ReadOnlySpan x, ReadOnlySpan y) where TBinary : struct, IBinaryOperator - where TAggregate : struct, IBinaryOperator + where TAggregate : struct, IAggregationOperator { - // Initialize the result to the identity value - float result = identityValue; - int i = 0; + Debug.Assert(x.Length == y.Length); -#if NET8_0_OR_GREATER - if (Vector512.IsHardwareAccelerated && x.Length >= Vector512.Count * 2) + if (x.IsEmpty) { - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float yRef = ref MemoryMarshal.GetReference(y); + return 0; + } + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated && x.Length >= Vector512.Count) + { // Load the first vector as the initial set of results - Vector512 resultVector = TBinary.Invoke(Vector512.LoadUnsafe(ref xRef, 0), Vector512.LoadUnsafe(ref yRef, 0)); + Vector512 result = 
TBinary.Invoke(Vector512.LoadUnsafe(ref xRef, 0), Vector512.LoadUnsafe(ref yRef, 0)); int oneVectorFromEnd = x.Length - Vector512.Count; + int i = Vector512.Count; // Aggregate additional vectors into the result as long as there's at // least one full vector left to process. - i = Vector512.Count; - do + while (i <= oneVectorFromEnd) { - resultVector = TAggregate.Invoke(resultVector, TBinary.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i), Vector512.LoadUnsafe(ref yRef, (uint)i))); + result = TAggregate.Invoke(result, TBinary.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i), Vector512.LoadUnsafe(ref yRef, (uint)i))); i += Vector512.Count; } - while (i <= oneVectorFromEnd); + + // Process the last vector in the spans, masking off elements already processed. + if (i != x.Length) + { + result = TAggregate.Invoke(result, + Vector512.ConditionalSelect( + Vector512.Equals(LoadRemainderMaskSingleVector512(x.Length - i), Vector512.Zero), + Vector512.Create(TAggregate.IdentityValue), + TBinary.Invoke( + Vector512.LoadUnsafe(ref xRef, (uint)(x.Length - Vector512.Count)), + Vector512.LoadUnsafe(ref yRef, (uint)(x.Length - Vector512.Count))))); + } // Aggregate the lanes in the vector back into the scalar result - result = TAggregate.Invoke(result, TAggregate.Invoke(resultVector)); + return TAggregate.Invoke(result); } - else #endif - if (Vector256.IsHardwareAccelerated && x.Length >= Vector256.Count * 2) - { - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float yRef = ref MemoryMarshal.GetReference(y); + if (Vector256.IsHardwareAccelerated && x.Length >= Vector256.Count) + { // Load the first vector as the initial set of results - Vector256 resultVector = TBinary.Invoke(Vector256.LoadUnsafe(ref xRef, 0), Vector256.LoadUnsafe(ref yRef, 0)); + Vector256 result = TBinary.Invoke(Vector256.LoadUnsafe(ref xRef, 0), Vector256.LoadUnsafe(ref yRef, 0)); int oneVectorFromEnd = x.Length - Vector256.Count; + int i = Vector256.Count; // Aggregate additional vectors into the 
result as long as there's at // least one full vector left to process. - i = Vector256.Count; - do + while (i <= oneVectorFromEnd) { - resultVector = TAggregate.Invoke(resultVector, TBinary.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i), Vector256.LoadUnsafe(ref yRef, (uint)i))); + result = TAggregate.Invoke(result, TBinary.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i), Vector256.LoadUnsafe(ref yRef, (uint)i))); i += Vector256.Count; } - while (i <= oneVectorFromEnd); + + // Process the last vector in the spans, masking off elements already processed. + if (i != x.Length) + { + result = TAggregate.Invoke(result, + Vector256.ConditionalSelect( + Vector256.Equals(LoadRemainderMaskSingleVector256(x.Length - i), Vector256.Zero), + Vector256.Create(TAggregate.IdentityValue), + TBinary.Invoke( + Vector256.LoadUnsafe(ref xRef, (uint)(x.Length - Vector256.Count)), + Vector256.LoadUnsafe(ref yRef, (uint)(x.Length - Vector256.Count))))); + } // Aggregate the lanes in the vector back into the scalar result - result = TAggregate.Invoke(result, TAggregate.Invoke(resultVector)); + return TAggregate.Invoke(result); } - else if (Vector128.IsHardwareAccelerated && x.Length >= Vector128.Count * 2) - { - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float yRef = ref MemoryMarshal.GetReference(y); + if (Vector128.IsHardwareAccelerated && x.Length >= Vector128.Count) + { // Load the first vector as the initial set of results - Vector128 resultVector = TBinary.Invoke(Vector128.LoadUnsafe(ref xRef, 0), Vector128.LoadUnsafe(ref yRef, 0)); + Vector128 result = TBinary.Invoke(Vector128.LoadUnsafe(ref xRef, 0), Vector128.LoadUnsafe(ref yRef, 0)); int oneVectorFromEnd = x.Length - Vector128.Count; + int i = Vector128.Count; // Aggregate additional vectors into the result as long as there's at // least one full vector left to process. 
- i = Vector128.Count; - do + while (i <= oneVectorFromEnd) { - resultVector = TAggregate.Invoke(resultVector, TBinary.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i), Vector128.LoadUnsafe(ref yRef, (uint)i))); + result = TAggregate.Invoke(result, TBinary.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i), Vector128.LoadUnsafe(ref yRef, (uint)i))); i += Vector128.Count; } - while (i <= oneVectorFromEnd); + + // Process the last vector in the spans, masking off elements already processed. + if (i != x.Length) + { + result = TAggregate.Invoke(result, + Vector128.ConditionalSelect( + Vector128.Equals(LoadRemainderMaskSingleVector128(x.Length - i), Vector128.Zero), + Vector128.Create(TAggregate.IdentityValue), + TBinary.Invoke( + Vector128.LoadUnsafe(ref xRef, (uint)(x.Length - Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, (uint)(x.Length - Vector128.Count))))); + } // Aggregate the lanes in the vector back into the scalar result - result = TAggregate.Invoke(result, TAggregate.Invoke(resultVector)); + return TAggregate.Invoke(result); } - // Aggregate the remaining items in the input span. - for (; (uint)i < (uint)x.Length; i++) + // Vectorization isn't supported or there are too few elements to vectorize. + // Use a scalar implementation. { - result = TAggregate.Invoke(result, TBinary.Invoke(x[i], y[i])); - } + float result = TBinary.Invoke(xRef, yRef); + for (int i = 1; i < x.Length; i++) + { + result = TAggregate.Invoke(result, + TBinary.Invoke( + Unsafe.Add(ref xRef, i), + Unsafe.Add(ref yRef, i))); + } - return result; + return result; + } } /// - /// This is the same as , + /// This is the same as , /// except it early exits on NaN. 
/// - private static float MinMaxCore(ReadOnlySpan x) where TMinMax : struct, IBinaryOperator + private static float MinMaxCore(ReadOnlySpan x) where TMinMax : struct, IAggregationOperator { if (x.IsEmpty) { @@ -356,28 +470,24 @@ private static float MinMaxCore(ReadOnlySpan x) where TMinMax : // otherwise returns the greater of the inputs. // It treats +0 as greater than -0 as per the specification. - // Initialize the result to the identity value - float result = x[0]; - int i = 0; - #if NET8_0_OR_GREATER - if (Vector512.IsHardwareAccelerated && x.Length >= Vector512.Count * 2) + if (Vector512.IsHardwareAccelerated && x.Length >= Vector512.Count) { ref float xRef = ref MemoryMarshal.GetReference(x); // Load the first vector as the initial set of results, and bail immediately // to scalar handling if it contains any NaNs (which don't compare equally to themselves). - Vector512 resultVector = Vector512.LoadUnsafe(ref xRef, 0), current; - if (!Vector512.EqualsAll(resultVector, resultVector)) + Vector512 result = Vector512.LoadUnsafe(ref xRef, 0), current; + if (!Vector512.EqualsAll(result, result)) { - return GetFirstNaN(resultVector); + return GetFirstNaN(result); } int oneVectorFromEnd = x.Length - Vector512.Count; + int i = Vector512.Count; // Aggregate additional vectors into the result as long as there's at least one full vector left to process. - i = Vector512.Count; - do + while (i <= oneVectorFromEnd) { // Load the next vector, and early exit on NaN. current = Vector512.LoadUnsafe(ref xRef, (uint)i); @@ -386,10 +496,9 @@ private static float MinMaxCore(ReadOnlySpan x) where TMinMax : return GetFirstNaN(current); } - resultVector = TMinMax.Invoke(resultVector, current); + result = TMinMax.Invoke(result, current); i += Vector512.Count; } - while (i <= oneVectorFromEnd); // If any elements remain, handle them in one final vector. 
if (i != x.Length) @@ -400,31 +509,31 @@ private static float MinMaxCore(ReadOnlySpan x) where TMinMax : return GetFirstNaN(current); } - resultVector = TMinMax.Invoke(resultVector, current); + result = TMinMax.Invoke(result, current); } // Aggregate the lanes in the vector to create the final scalar result. - return TMinMax.Invoke(resultVector); + return TMinMax.Invoke(result); } #endif - if (Vector256.IsHardwareAccelerated && x.Length >= Vector256.Count * 2) + if (Vector256.IsHardwareAccelerated && x.Length >= Vector256.Count) { ref float xRef = ref MemoryMarshal.GetReference(x); // Load the first vector as the initial set of results, and bail immediately // to scalar handling if it contains any NaNs (which don't compare equally to themselves). - Vector256 resultVector = Vector256.LoadUnsafe(ref xRef, 0), current; - if (!Vector256.EqualsAll(resultVector, resultVector)) + Vector256 result = Vector256.LoadUnsafe(ref xRef, 0), current; + if (!Vector256.EqualsAll(result, result)) { - return GetFirstNaN(resultVector); + return GetFirstNaN(result); } int oneVectorFromEnd = x.Length - Vector256.Count; + int i = Vector256.Count; // Aggregate additional vectors into the result as long as there's at least one full vector left to process. - i = Vector256.Count; - do + while (i <= oneVectorFromEnd) { // Load the next vector, and early exit on NaN. current = Vector256.LoadUnsafe(ref xRef, (uint)i); @@ -433,10 +542,9 @@ private static float MinMaxCore(ReadOnlySpan x) where TMinMax : return GetFirstNaN(current); } - resultVector = TMinMax.Invoke(resultVector, current); + result = TMinMax.Invoke(result, current); i += Vector256.Count; } - while (i <= oneVectorFromEnd); // If any elements remain, handle them in one final vector. 
if (i != x.Length) @@ -447,30 +555,30 @@ private static float MinMaxCore(ReadOnlySpan x) where TMinMax : return GetFirstNaN(current); } - resultVector = TMinMax.Invoke(resultVector, current); + result = TMinMax.Invoke(result, current); } // Aggregate the lanes in the vector to create the final scalar result. - return TMinMax.Invoke(resultVector); + return TMinMax.Invoke(result); } - if (Vector128.IsHardwareAccelerated && x.Length >= Vector128.Count * 2) + if (Vector128.IsHardwareAccelerated && x.Length >= Vector128.Count) { ref float xRef = ref MemoryMarshal.GetReference(x); // Load the first vector as the initial set of results, and bail immediately // to scalar handling if it contains any NaNs (which don't compare equally to themselves). - Vector128 resultVector = Vector128.LoadUnsafe(ref xRef, 0), current; - if (!Vector128.EqualsAll(resultVector, resultVector)) + Vector128 result = Vector128.LoadUnsafe(ref xRef, 0), current; + if (!Vector128.EqualsAll(result, result)) { - return GetFirstNaN(resultVector); + return GetFirstNaN(result); } int oneVectorFromEnd = x.Length - Vector128.Count; + int i = Vector128.Count; // Aggregate additional vectors into the result as long as there's at least one full vector left to process. - i = Vector128.Count; - do + while (i <= oneVectorFromEnd) { // Load the next vector, and early exit on NaN. current = Vector128.LoadUnsafe(ref xRef, (uint)i); @@ -479,10 +587,9 @@ private static float MinMaxCore(ReadOnlySpan x) where TMinMax : return GetFirstNaN(current); } - resultVector = TMinMax.Invoke(resultVector, current); + result = TMinMax.Invoke(result, current); i += Vector128.Count; } - while (i <= oneVectorFromEnd); // If any elements remain, handle them in one final vector. 
if (i != x.Length) @@ -493,26 +600,34 @@ private static float MinMaxCore(ReadOnlySpan x) where TMinMax : return GetFirstNaN(current); } - resultVector = TMinMax.Invoke(resultVector, current); + result = TMinMax.Invoke(result, current); } // Aggregate the lanes in the vector to create the final scalar result. - return TMinMax.Invoke(resultVector); + return TMinMax.Invoke(result); } // Scalar path used when either vectorization is not supported or the input is too small to vectorize. - for (; (uint)i < (uint)x.Length; i++) { - float current = x[i]; - if (float.IsNaN(current)) + float result = x[0]; + if (float.IsNaN(result)) { - return current; + return result; } - result = TMinMax.Invoke(result, current); - } + for (int i = 1; i < x.Length; i++) + { + float current = x[i]; + if (float.IsNaN(current)) + { + return current; + } + + result = TMinMax.Invoke(result, current); + } - return result; + return result; + } } private static unsafe void InvokeSpanIntoSpan( @@ -1287,7 +1402,27 @@ private static float GetFirstNaN(Vector512 vector) => private static float Log2(float x) => MathF.Log2(x); - private readonly struct AddOperator : IBinaryOperator + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe Vector128 LoadRemainderMaskSingleVector128(int validItems) => + Vector128.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(RemainderUInt32Mask_16x16)), + (uint)((validItems * 16) + 12)); // last four floats in the row + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe Vector256 LoadRemainderMaskSingleVector256(int validItems) => + Vector256.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(RemainderUInt32Mask_16x16)), + (uint)((validItems * 16) + 8)); // last eight floats in the row + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe Vector512 LoadRemainderMaskSingleVector512(int validItems) => + Vector512.LoadUnsafe( + ref Unsafe.As(ref 
MemoryMarshal.GetReference(RemainderUInt32Mask_16x16)), + (uint)(validItems * 16)); // all sixteen floats in the row +#endif + + private readonly struct AddOperator : IAggregationOperator { public static float Invoke(float x, float y) => x + y; public static Vector128 Invoke(Vector128 x, Vector128 y) => x + y; @@ -1301,6 +1436,8 @@ private static float GetFirstNaN(Vector512 vector) => #if NET8_0_OR_GREATER public static float Invoke(Vector512 x) => Vector512.Sum(x); #endif + + public static float IdentityValue => 0; } private readonly struct SubtractOperator : IBinaryOperator @@ -1342,7 +1479,7 @@ public static Vector512 Invoke(Vector512 x, Vector512 y) #endif } - private readonly struct MultiplyOperator : IBinaryOperator + private readonly struct MultiplyOperator : IAggregationOperator { public static float Invoke(float x, float y) => x * y; public static Vector128 Invoke(Vector128 x, Vector128 y) => x * y; @@ -1356,6 +1493,8 @@ public static Vector512 Invoke(Vector512 x, Vector512 y) #if NET8_0_OR_GREATER public static float Invoke(Vector512 x) => HorizontalAggregate(x); #endif + + public static float IdentityValue => 1; } private readonly struct DivideOperator : IBinaryOperator @@ -1368,7 +1507,7 @@ public static Vector512 Invoke(Vector512 x, Vector512 y) #endif } - private readonly struct MaxOperator : IBinaryOperator + private readonly struct MaxOperator : IAggregationOperator { [MethodImpl(MethodImplOptions.AggressiveInlining)] public static float Invoke(float x, float y) => @@ -1457,7 +1596,7 @@ public static Vector512 Invoke(Vector512 x, Vector512 y) => #endif } - private readonly struct MaxMagnitudeOperator : IBinaryOperator + private readonly struct MaxMagnitudeOperator : IAggregationOperator { [MethodImpl(MethodImplOptions.AggressiveInlining)] public static float Invoke(float x, float y) @@ -1558,7 +1697,7 @@ public static Vector512 Invoke(Vector512 x, Vector512 y) #endif } - private readonly struct MinOperator : IBinaryOperator + private readonly struct 
MinOperator : IAggregationOperator { [MethodImpl(MethodImplOptions.AggressiveInlining)] public static float Invoke(float x, float y) => @@ -1647,7 +1786,7 @@ public static Vector512 Invoke(Vector512 x, Vector512 y) => #endif } - private readonly struct MinMagnitudeOperator : IBinaryOperator + private readonly struct MinMagnitudeOperator : IAggregationOperator { [MethodImpl(MethodImplOptions.AggressiveInlining)] public static float Invoke(float x, float y) @@ -1825,14 +1964,17 @@ private interface IBinaryOperator #if NET8_0_OR_GREATER static abstract Vector512 Invoke(Vector512 x, Vector512 y); #endif + } - // Operations for aggregating all lanes in a vector into a single value. - // These are not supported on most implementations. - static virtual float Invoke(Vector128 x) => throw new NotSupportedException(); - static virtual float Invoke(Vector256 x) => throw new NotSupportedException(); + private interface IAggregationOperator : IBinaryOperator + { + static abstract float Invoke(Vector128 x); + static abstract float Invoke(Vector256 x); #if NET8_0_OR_GREATER - static virtual float Invoke(Vector512 x) => throw new NotSupportedException(); + static abstract float Invoke(Vector512 x); #endif + + static virtual float IdentityValue => throw new NotSupportedException(); } private interface ITernaryOperator diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs index a19f2529ab99c..e05e54bcad769 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. 
+using System.Diagnostics; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; @@ -18,9 +19,9 @@ private static float CosineSimilarityCore(ReadOnlySpan x, ReadOnlySpan= Vector.Count) + if (Vector.IsHardwareAccelerated && + Vector.Count <= 16 && // currently never greater than 8, but 16 would occur if/when AVX512 is supported, and logic in remainder handling assumes that maximum + x.Length >= Vector.Count) { ref float xRef = ref MemoryMarshal.GetReference(x); ref float yRef = ref MemoryMarshal.GetReference(y); @@ -31,6 +32,7 @@ private static float CosineSimilarityCore(ReadOnlySpan x, ReadOnlySpan.Count; + int i = 0; do { Vector xVec = AsVector(ref xRef, i); @@ -44,6 +46,21 @@ private static float CosineSimilarityCore(ReadOnlySpan x, ReadOnlySpan xVec = AsVector(ref xRef, x.Length - Vector.Count); + Vector yVec = AsVector(ref yRef, x.Length - Vector.Count); + + Vector remainderMask = LoadRemainderMaskSingleVector(x.Length - i); + xVec &= remainderMask; + yVec &= remainderMask; + + dotProductVector += xVec * yVec; + xSumOfSquaresVector += xVec * xVec; + ySumOfSquaresVector += yVec * yVec; + } + // Sum the vector lanes into the scalar result. 
for (int e = 0; e < Vector.Count; e++) { @@ -52,13 +69,16 @@ private static float CosineSimilarityCore(ReadOnlySpan x, ReadOnlySpan x, ReadOnlySpan( - float identityValue, ReadOnlySpan x, TLoad load = default, TAggregate aggregate = default) + ReadOnlySpan x, TLoad load = default, TAggregate aggregate = default) where TLoad : struct, IUnaryOperator - where TAggregate : struct, IBinaryOperator + where TAggregate : struct, IAggregationOperator { - // Initialize the result to the identity value - float result = identityValue; - int i = 0; + if (x.Length == 0) + { + return 0; + } - if (Vector.IsHardwareAccelerated && x.Length >= Vector.Count * 2) + float result; + + if (Vector.IsHardwareAccelerated && x.Length >= Vector.Count) { ref float xRef = ref MemoryMarshal.GetReference(x); // Load the first vector as the initial set of results Vector resultVector = load.Invoke(AsVector(ref xRef, 0)); int oneVectorFromEnd = x.Length - Vector.Count; + int i = Vector.Count; // Aggregate additional vectors into the result as long as there's at // least one full vector left to process. - i = Vector.Count; - do + while (i <= oneVectorFromEnd) { resultVector = aggregate.Invoke(resultVector, load.Invoke(AsVector(ref xRef, i))); i += Vector.Count; } - while (i <= oneVectorFromEnd); + + // Process the last vector in the span, masking off elements already processed. + if (i != x.Length) + { + resultVector = aggregate.Invoke(resultVector, + Vector.ConditionalSelect( + Vector.Equals(LoadRemainderMaskSingleVector(x.Length - i), Vector.Zero), + new Vector(aggregate.IdentityValue), + load.Invoke(AsVector(ref xRef, x.Length - Vector.Count)))); + } // Aggregate the lanes in the vector back into the scalar result - for (int f = 0; f < Vector.Count; f++) + result = resultVector[0]; + for (int f = 1; f < Vector.Count; f++) { result = aggregate.Invoke(result, resultVector[f]); } + + return result; } // Aggregate the remaining items in the input span. 
- for (; (uint)i < (uint)x.Length; i++) + result = load.Invoke(x[0]); + for (int i = 1; i < x.Length; i++) { result = aggregate.Invoke(result, load.Invoke(x[i])); } @@ -109,42 +145,62 @@ private static float Aggregate( } private static float Aggregate( - float identityValue, ReadOnlySpan x, ReadOnlySpan y, TBinary binary = default, TAggregate aggregate = default) + ReadOnlySpan x, ReadOnlySpan y, TBinary binary = default, TAggregate aggregate = default) where TBinary : struct, IBinaryOperator - where TAggregate : struct, IBinaryOperator + where TAggregate : struct, IAggregationOperator { - // Initialize the result to the identity value - float result = identityValue; - int i = 0; + Debug.Assert(x.Length == y.Length); - if (Vector.IsHardwareAccelerated && x.Length >= Vector.Count * 2) + if (x.Length == 0) { - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float yRef = ref MemoryMarshal.GetReference(y); + return 0; + } + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + + float result; + if (Vector.IsHardwareAccelerated && x.Length >= Vector.Count) + { // Load the first vector as the initial set of results Vector resultVector = binary.Invoke(AsVector(ref xRef, 0), AsVector(ref yRef, 0)); int oneVectorFromEnd = x.Length - Vector.Count; + int i = Vector.Count; // Aggregate additional vectors into the result as long as there's at // least one full vector left to process. - i = Vector.Count; - do + while (i <= oneVectorFromEnd) { resultVector = aggregate.Invoke(resultVector, binary.Invoke(AsVector(ref xRef, i), AsVector(ref yRef, i))); i += Vector.Count; } - while (i <= oneVectorFromEnd); + + // Process the last vector in the spans, masking off elements already processed. 
+ if (i != x.Length) + { + resultVector = aggregate.Invoke(resultVector, + Vector.ConditionalSelect( + Vector.Equals(LoadRemainderMaskSingleVector(x.Length - i), Vector.Zero), + new Vector(aggregate.IdentityValue), + binary.Invoke( + AsVector(ref xRef, x.Length - Vector.Count), + AsVector(ref yRef, x.Length - Vector.Count)))); + } // Aggregate the lanes in the vector back into the scalar result - for (int f = 0; f < Vector.Count; f++) + result = resultVector[0]; + for (int f = 1; f < Vector.Count; f++) { result = aggregate.Invoke(result, resultVector[f]); } + + return result; } // Aggregate the remaining items in the input span. - for (; (uint)i < (uint)x.Length; i++) + result = binary.Invoke(x[0], y[0]); + for (int i = 1; i < x.Length; i++) { result = aggregate.Invoke(result, binary.Invoke(x[i], y[i])); } @@ -164,11 +220,10 @@ private static float MinMaxCore(ReadOnlySpan x, TMinMax minMax = // otherwise returns the greater of the inputs. // It treats +0 as greater than -0 as per the specification. - // Initialize the result to the identity value float result = x[0]; int i = 0; - if (Vector.IsHardwareAccelerated && x.Length >= Vector.Count * 2) + if (Vector.IsHardwareAccelerated && x.Length >= Vector.Count) { ref float xRef = ref MemoryMarshal.GetReference(x); @@ -178,10 +233,10 @@ private static float MinMaxCore(ReadOnlySpan x, TMinMax minMax = if (Vector.EqualsAll(resultVector, resultVector)) { int oneVectorFromEnd = x.Length - Vector.Count; + i = Vector.Count; // Aggregate additional vectors into the result as long as there's at least one full vector left to process. - i = Vector.Count; - do + while (i <= oneVectorFromEnd) { // Load the next vector, and early exit on NaN. current = AsVector(ref xRef, i); @@ -193,7 +248,6 @@ private static float MinMaxCore(ReadOnlySpan x, TMinMax minMax = resultVector = minMax.Invoke(resultVector, current); i += Vector.Count; } - while (i <= oneVectorFromEnd); // If any elements remain, handle them in one final vector. 
if (i != x.Length) @@ -582,10 +636,20 @@ private static unsafe Vector IsNegative(Vector f) => private static float Log2(float x) => MathF.Log(x, 2); - private readonly struct AddOperator : IBinaryOperator + private static unsafe Vector LoadRemainderMaskSingleVector(int validItems) + { + Debug.Assert(Vector.Count is 4 or 8 or 16); + + return AsVector( + ref Unsafe.As(ref MemoryMarshal.GetReference(RemainderUInt32Mask_16x16)), + (validItems * 16) + (16 - Vector.Count)); + } + + private readonly struct AddOperator : IAggregationOperator { public float Invoke(float x, float y) => x + y; public Vector Invoke(Vector x, Vector y) => x + y; + public float IdentityValue => 0; } private readonly struct SubtractOperator : IBinaryOperator @@ -609,10 +673,11 @@ public Vector Invoke(Vector x, Vector y) } } - private readonly struct MultiplyOperator : IBinaryOperator + private readonly struct MultiplyOperator : IAggregationOperator { public float Invoke(float x, float y) => x * y; public Vector Invoke(Vector x, Vector y) => x * y; + public float IdentityValue => 1; } private readonly struct DivideOperator : IBinaryOperator @@ -826,6 +891,11 @@ private interface IBinaryOperator Vector Invoke(Vector x, Vector y); } + private interface IAggregationOperator : IBinaryOperator + { + float IdentityValue { get; } + } + private interface ITernaryOperator { float Invoke(float x, float y, float z);