From e1f0f71fdadba287b56b8f5444edd651a7ed8020 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Fri, 29 Sep 2023 09:16:56 -0400 Subject: [PATCH] Enable TensorPrimitives to perform in-place operations Some operations would produce incorrect results if the same span was passed as both an input and an output. When vectorization was employed but the span's length wasn't a perfect multiple of a vector, we'd do the standard trick of performing one last operation on the last vector's worth of data; however, that relies on the operation being idempotent, and if a previous operation has overwritten input with a new value due to the same memory being used for input and output, some operations won't be idempotent. This fixes that by masking off the already processed elements. It adds tests to validate in-place use works, and it updates the docs to carve out this valid overlapping. --- .../Numerics/Tensors/TensorPrimitives.cs | 106 ++-- .../Tensors/TensorPrimitives.netcore.cs | 159 ++++-- .../Tensors/TensorPrimitives.netstandard.cs | 51 +- .../tests/TensorPrimitivesTests.cs | 520 +++++++++++++++++- 4 files changed, 740 insertions(+), 96 deletions(-) diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs index 0da8b6dfcdec2..41fe81416b27a 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs @@ -22,7 +22,8 @@ public static partial class TensorPrimitives /// If a value is equal to , the result stored into the corresponding destination location is the original NaN value with the sign bit removed. /// /// - /// and may not overlap; if they do, behavior is undefined. + /// may overlap with , but only if the input and the output span begin at the same memory + /// location; otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters. /// /// public static void Abs(ReadOnlySpan x, Span destination) => @@ -39,7 +40,9 @@ public static void Abs(ReadOnlySpan x, Span destination) => /// This method effectively computes [i] = [i] + [i]. /// /// - /// and may overlap, but neither may overlap with ; if they do, behavior is undefined. + /// and may overlap arbitrarily, but they may only overlap with + /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined. It is safe, for example, + /// to use the same span for any subset of the span parameters, such as to perform an in-place operation. /// /// /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. @@ -58,7 +61,9 @@ public static unsafe void Add(ReadOnlySpan x, ReadOnlySpan y, Span /// This method effectively computes [i] = [i] + . /// /// - /// and may not overlap; if they do, behavior is undefined. + /// and may overlap, but only if they start at the same memory location; + /// otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters, such as to perform + /// an in-place operation. /// /// /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. @@ -79,7 +84,9 @@ public static void Add(ReadOnlySpan x, float y, Span destination) /// This method effectively computes [i] = ([i] + [i]) * [i]. /// /// - /// , , and may overlap, but none of them may overlap with ; if they do, behavior is undefined. + /// , , and may overlap arbitrarily, but they may only overlap with + /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined. + /// It is safe, for example, to use the same span for any subset of the span parameters, such as to perform an in-place operation. /// /// /// If any of the element-wise input values is equal to , the resulting element-wise value is also NaN. @@ -100,7 +107,9 @@ public static void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, Rea /// This method effectively computes [i] = ([i] + [i]) * . /// /// - /// and may overlap, but neither may overlap with ; if they do, behavior is undefined. + /// and may overlap arbitrarily, but they may only overlap with + /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined. + /// It is safe, for example, to use the same span for any subset of the span parameters, such as to perform an in-place operation. /// /// /// If any of the element-wise input values is equal to , the resulting element-wise value is also NaN. @@ -121,7 +130,9 @@ public static void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, flo /// This method effectively computes [i] = ([i] + ) * [i]. /// /// - /// and may overlap, but neither may overlap with ; if they do, behavior is undefined. + /// and may overlap arbitrarily, but they may only overlap with + /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined. + /// It is safe, for example, to use the same span for any subset of the span parameters, such as to perform an in-place operation. /// /// /// If any of the element-wise input values is equal to , the resulting element-wise value is also NaN. @@ -139,7 +150,8 @@ public static void AddMultiply(ReadOnlySpan x, float y, ReadOnlySpan[i] = .Cosh([i]). /// /// - /// and may not overlap; if they do, behavior is undefined. + /// may overlap with , but only if the input and the output span begin at the same memory + /// location; otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters. /// /// /// If a value is equal to or , the result stored into the corresponding destination location is set to . @@ -250,7 +262,9 @@ public static float Distance(ReadOnlySpan x, ReadOnlySpan y) /// This method effectively computes [i] = [i] / [i]. /// /// - /// and may overlap, but neither may overlap with ; if they do, behavior is undefined. + /// and may overlap arbitrarily, but they may only overlap with + /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined. It is safe, for example, + /// to use the same span for any subset of the span parameters, such as to perform an in-place operation. /// /// /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. @@ -269,7 +283,8 @@ public static void Divide(ReadOnlySpan x, ReadOnlySpan y, Span[i] = [i] / . /// /// - /// and may not overlap; if they do, behavior is undefined. + /// may overlap with , but only if the input and the output span begin at the same memory + /// location; otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters. /// /// /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. @@ -320,7 +335,8 @@ public static float Dot(ReadOnlySpan x, ReadOnlySpan y) /// This method effectively computes [i] = .Exp([i]). /// /// - /// and may not overlap; if they do, behavior is undefined. + /// may overlap with , but only if the input and the output span begin at the same memory + /// location; otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters. /// /// /// If a value equals or , the result stored into the corresponding destination location is set to NaN. @@ -559,7 +575,8 @@ public static unsafe int IndexOfMinMagnitude(ReadOnlySpan x) /// This method effectively computes [i] = .Log([i]). /// /// - /// and may not overlap; if they do, behavior is undefined. + /// may overlap with , but only if the input and the output span begin at the same memory + /// location; otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters. /// /// /// If a value equals 0, the result stored into the corresponding destination location is set to . @@ -594,7 +611,8 @@ public static void Log(ReadOnlySpan x, Span destination) /// This method effectively computes [i] = .Log2([i]). /// /// - /// and may not overlap; if they do, behavior is undefined. + /// may overlap with , but only if the input and the output span begin at the same memory + /// location; otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters. /// /// /// If a value equals 0, the result stored into the corresponding destination location is set to . @@ -648,7 +666,9 @@ public static float Max(ReadOnlySpan x) => /// This method effectively computes [i] = MathF.Max([i], [i]). /// /// - /// and may overlap, but neither may overlap with ; if they do, behavior is undefined. + /// and may overlap arbitrarily, but they may only overlap with + /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined. It is safe, for example, + /// to use the same span for any subset of the span parameters, such as to perform an in-place operation. /// /// /// The determination of the maximum element matches the IEEE 754:2019 `maximum` function. If either value is equal to , @@ -689,12 +709,9 @@ public static float MaxMagnitude(ReadOnlySpan x) => /// This method effectively computes [i] = MathF.MaxMagnitude([i], [i]). /// /// - /// The determination of the maximum magnitude matches the IEEE 754:2019 `maximumMagnitude` function. If either value is equal to , - /// that value is stored as the result. If the two values have the same magnitude and one is positive and the other is negative, - /// the positive value is considered to have the larger magnitude. - /// - /// - /// and may overlap, but neither may overlap with ; if they do, behavior is undefined. + /// and may overlap arbitrarily, but they may only overlap with + /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined. It is safe, for example, + /// to use the same span for any subset of the span parameters, such as to perform an in-place operation. /// /// /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different @@ -736,7 +753,9 @@ public static float Min(ReadOnlySpan x) => /// that value is stored as the result. Positive 0 is considered greater than negative 0. /// /// - /// and may overlap, but neither may overlap with ; if they do, behavior is undefined. + /// and may overlap arbitrarily, but they may only overlap with + /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined. It is safe, for example, + /// to use the same span for any subset of the span parameters, such as to perform an in-place operation. /// /// /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different @@ -778,7 +797,9 @@ public static float MinMagnitude(ReadOnlySpan x) => /// the negative value is considered to have the smaller magnitude. /// /// - /// and may overlap, but neither may overlap with ; if they do, behavior is undefined. + /// and may overlap arbitrarily, but they may only overlap with + /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined. It is safe, for example, + /// to use the same span for any subset of the span parameters, such as to perform an in-place operation. /// /// /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different @@ -799,7 +820,9 @@ public static void MinMagnitude(ReadOnlySpan x, ReadOnlySpan y, Sp /// This method effectively computes [i] = [i] * [i]. /// /// - /// and may overlap, but neither may overlap with ; if they do, behavior is undefined. + /// and may overlap arbitrarily, but they may only overlap with + /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined. It is safe, for example, + /// to use the same span for any subset of the span parameters, such as to perform an in-place operation. /// /// /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. @@ -819,7 +842,8 @@ public static void Multiply(ReadOnlySpan x, ReadOnlySpan y, Spanscal method defined by BLAS1. /// /// - /// and may not overlap; if they do, behavior is undefined. + /// may overlap with , but only if the input and the output span begin at the same memory + /// location; otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters. /// /// /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. @@ -840,12 +864,16 @@ public static void Multiply(ReadOnlySpan x, float y, Span destinat /// This method effectively computes [i] = ([i] * [i]) + [i]. /// /// - /// , , and may overlap, but none of them may overlap with ; if they do, behavior is undefined. + /// , , and may overlap arbitrarily, but they may only overlap with + /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined. + /// It is safe, for example, to use the same span for any subset of the span parameters, such as to perform an in-place operation. /// /// /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. /// /// + + public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan addend, Span destination) => InvokeSpanSpanSpanIntoSpan(x, y, addend, destination); @@ -862,7 +890,9 @@ public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, Rea /// It corresponds to the axpy method defined by BLAS1. /// /// - /// and may overlap, but neither may overlap with ; if they do, behavior is undefined. + /// and may overlap arbitrarily, but they may only overlap with + /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined. + /// It is safe, for example, to use the same span for any subset of the span parameters, such as to perform an in-place operation. /// /// /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. @@ -883,7 +913,9 @@ public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, flo /// This method effectively computes [i] = ([i] * ) + [i]. /// /// - /// and may overlap, but neither may overlap with ; if they do, behavior is undefined. + /// and may overlap arbitrarily, but they may only overlap with + /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined. + /// It is safe, for example, to use the same span for any subset of the span parameters, such as to perform an in-place operation. /// /// /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. @@ -901,7 +933,8 @@ public static void MultiplyAdd(ReadOnlySpan x, float y, ReadOnlySpan[i] = -[i]. /// /// - /// and may not overlap; if they do, behavior is undefined. + /// may overlap with , but only if the input and the output span begin at the same memory + /// location; otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters. /// /// /// If any of the element-wise input values is equal to , the resulting element-wise value is also NaN. @@ -1035,7 +1068,8 @@ public static float ProductOfSums(ReadOnlySpan x, ReadOnlySpan y) /// This method effectively computes [i] = 1f / (1f + .Exp(-[i])). /// /// - /// and may not overlap; if they do, behavior is undefined. + /// may overlap with , but only if the input and the output span begin at the same memory + /// location; otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters. /// /// /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different @@ -1069,7 +1103,8 @@ public static void Sigmoid(ReadOnlySpan x, Span destination) /// This method effectively computes [i] = .Sinh([i]). /// /// - /// and may not overlap; if they do, behavior is undefined. + /// may overlap with , but only if the input and the output span begin at the same memory + /// location; otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters. /// /// /// If a value is equal to , , or , @@ -1107,7 +1142,8 @@ public static void Sinh(ReadOnlySpan x, Span destination) /// It then effectively computes [i] = MathF.Exp([i]) / sum. /// /// - /// and may not overlap; if they do, behavior is undefined. + /// may overlap with , but only if the input and the output span begin at the same memory + /// location; otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters. /// /// /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different @@ -1150,7 +1186,9 @@ public static void SoftMax(ReadOnlySpan x, Span destination) /// This method effectively computes [i] = [i] - [i]. /// /// - /// and may overlap, but neither may overlap with ; if they do, behavior is undefined. + /// and may overlap arbitrarily, but they may only overlap with + /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined. + /// It is safe, for example, to use the same span for any subset of the span parameters, such as to perform an in-place operation. /// /// /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. @@ -1169,7 +1207,8 @@ public static void Subtract(ReadOnlySpan x, ReadOnlySpan y, Span[i] = [i] - . /// /// - /// and may not overlap; if they do, behavior is undefined. + /// may overlap with , but only if the input and the output span begin at the same memory + /// location; otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters. /// /// /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. @@ -1244,7 +1283,8 @@ public static float SumOfSquares(ReadOnlySpan x) => /// This method effectively computes [i] = .Tanh([i]). /// /// - /// and may not overlap; if they do, behavior is undefined. + /// may overlap with , but only if the input and the output span begin at the same memory + /// location; otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters. /// /// /// If a value is equal to , the corresponding destination location is set to -1. diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs index bed07eedfefd1..bd18b16d47b69 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs @@ -23,6 +23,9 @@ public static partial class TensorPrimitives /// /// This method effectively computes [i] = (Half)[i]. /// + /// + /// and must not overlap. If they do, behavior is undefined. + /// /// public static void ConvertToHalf(ReadOnlySpan source, Span destination) { @@ -48,6 +51,9 @@ public static void ConvertToHalf(ReadOnlySpan source, Span destinat /// /// This method effectively computes [i] = (float)[i]. /// + /// + /// and must not overlap. If they do, behavior is undefined. + /// /// public static void ConvertToSingle(ReadOnlySpan source, Span destination) { @@ -519,7 +525,10 @@ private static float MinMaxCore(ReadOnlySpan x) where TMinMax : return GetFirstNaN(current); } - result = TMinMax.Invoke(result, current); + result = Vector512.ConditionalSelect( + Vector512.Equals(LoadRemainderMaskSingleVector512(x.Length - i), Vector512.Zero), + result, + TMinMax.Invoke(result, current)); } // Aggregate the lanes in the vector to create the final scalar result. @@ -565,7 +574,10 @@ private static float MinMaxCore(ReadOnlySpan x) where TMinMax : return GetFirstNaN(current); } - result = TMinMax.Invoke(result, current); + result = Vector256.ConditionalSelect( + Vector256.Equals(LoadRemainderMaskSingleVector256(x.Length - i), Vector256.Zero), + result, + TMinMax.Invoke(result, current)); } // Aggregate the lanes in the vector to create the final scalar result. @@ -610,7 +622,10 @@ private static float MinMaxCore(ReadOnlySpan x) where TMinMax : return GetFirstNaN(current); } - result = TMinMax.Invoke(result, current); + result = Vector128.ConditionalSelect( + Vector128.Equals(LoadRemainderMaskSingleVector128(x.Length - i), Vector128.Zero), + result, + TMinMax.Invoke(result, current)); } // Aggregate the lanes in the vector to create the final scalar result. @@ -672,7 +687,10 @@ private static unsafe void InvokeSpanIntoSpan( if (i != x.Length) { uint lastVectorIndex = (uint)(x.Length - Vector512.Count); - TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); + Vector512.ConditionalSelect( + Vector512.Equals(LoadRemainderMaskSingleVector512(x.Length - i), Vector512.Zero), + Vector512.LoadUnsafe(ref dRef, lastVectorIndex), + TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex); } return; @@ -698,7 +716,10 @@ private static unsafe void InvokeSpanIntoSpan( if (i != x.Length) { uint lastVectorIndex = (uint)(x.Length - Vector256.Count); - TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); + Vector256.ConditionalSelect( + Vector256.Equals(LoadRemainderMaskSingleVector256(x.Length - i), Vector256.Zero), + Vector256.LoadUnsafe(ref dRef, lastVectorIndex), + TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex); } return; @@ -723,7 +744,10 @@ private static unsafe void InvokeSpanIntoSpan( if (i != x.Length) { uint lastVectorIndex = (uint)(x.Length - Vector128.Count); - TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); + Vector128.ConditionalSelect( + Vector128.Equals(LoadRemainderMaskSingleVector128(x.Length - i), Vector128.Zero), + Vector128.LoadUnsafe(ref dRef, lastVectorIndex), + TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex); } return; @@ -777,8 +801,11 @@ private static unsafe void InvokeSpanSpanIntoSpan( if (i != x.Length) { uint lastVectorIndex = (uint)(x.Length - Vector512.Count); - TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex), - Vector512.LoadUnsafe(ref yRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); + Vector512.ConditionalSelect( + Vector512.Equals(LoadRemainderMaskSingleVector512(x.Length - i), Vector512.Zero), + Vector512.LoadUnsafe(ref dRef, lastVectorIndex), + TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex), + Vector512.LoadUnsafe(ref yRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex); } return; @@ -805,8 +832,11 @@ private static unsafe void InvokeSpanSpanIntoSpan( if (i != x.Length) { uint lastVectorIndex = (uint)(x.Length - Vector256.Count); - TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex), - Vector256.LoadUnsafe(ref yRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); + Vector256.ConditionalSelect( + Vector256.Equals(LoadRemainderMaskSingleVector256(x.Length - i), Vector256.Zero), + Vector256.LoadUnsafe(ref dRef, lastVectorIndex), + TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex), + Vector256.LoadUnsafe(ref yRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex); } return; @@ -832,8 +862,11 @@ private static unsafe void InvokeSpanSpanIntoSpan( if (i != x.Length) { uint lastVectorIndex = (uint)(x.Length - Vector128.Count); - TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex), - Vector128.LoadUnsafe(ref yRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); + Vector128.ConditionalSelect( + Vector128.Equals(LoadRemainderMaskSingleVector128(x.Length - i), Vector128.Zero), + Vector128.LoadUnsafe(ref dRef, lastVectorIndex), + TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex), + Vector128.LoadUnsafe(ref yRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex); } return; @@ -884,8 +917,11 @@ private static unsafe void InvokeSpanScalarIntoSpan( if (i != x.Length) { uint lastVectorIndex = (uint)(x.Length - Vector512.Count); - TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex), - yVec).StoreUnsafe(ref dRef, lastVectorIndex); + Vector512.ConditionalSelect( + Vector512.Equals(LoadRemainderMaskSingleVector512(x.Length - i), Vector512.Zero), + Vector512.LoadUnsafe(ref dRef, lastVectorIndex), + TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex), + yVec)).StoreUnsafe(ref dRef, lastVectorIndex); } return; @@ -914,8 +950,11 @@ private static unsafe void InvokeSpanScalarIntoSpan( if (i != x.Length) { uint lastVectorIndex = (uint)(x.Length - Vector256.Count); - TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex), - yVec).StoreUnsafe(ref dRef, lastVectorIndex); + Vector256.ConditionalSelect( + Vector256.Equals(LoadRemainderMaskSingleVector256(x.Length - i), Vector256.Zero), + Vector256.LoadUnsafe(ref dRef, lastVectorIndex), + TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex), + yVec)).StoreUnsafe(ref dRef, lastVectorIndex); } return; @@ -943,8 +982,11 @@ private static unsafe void InvokeSpanScalarIntoSpan( if (i != x.Length) { uint lastVectorIndex = (uint)(x.Length - Vector128.Count); - TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex), - yVec).StoreUnsafe(ref dRef, lastVectorIndex); + Vector128.ConditionalSelect( + Vector128.Equals(LoadRemainderMaskSingleVector128(x.Length - i), Vector128.Zero), + Vector128.LoadUnsafe(ref dRef, lastVectorIndex), + TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex), + yVec)).StoreUnsafe(ref dRef, lastVectorIndex); } return; @@ -1001,9 +1043,12 @@ private static unsafe void InvokeSpanSpanSpanIntoSpan( if (i != x.Length) { uint lastVectorIndex = (uint)(x.Length - Vector512.Count); - TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex), - Vector512.LoadUnsafe(ref yRef, lastVectorIndex), - Vector512.LoadUnsafe(ref zRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); + Vector512.ConditionalSelect( + Vector512.Equals(LoadRemainderMaskSingleVector512(x.Length - i), Vector512.Zero), + Vector512.LoadUnsafe(ref dRef, lastVectorIndex), + TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex), + Vector512.LoadUnsafe(ref yRef, lastVectorIndex), + Vector512.LoadUnsafe(ref zRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex); } return; @@ -1031,9 +1076,12 @@ private static unsafe void InvokeSpanSpanSpanIntoSpan( if (i != x.Length) { uint lastVectorIndex = (uint)(x.Length - Vector256.Count); - TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex), - Vector256.LoadUnsafe(ref yRef, lastVectorIndex), - Vector256.LoadUnsafe(ref zRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); + Vector256.ConditionalSelect( + Vector256.Equals(LoadRemainderMaskSingleVector256(x.Length - i), Vector256.Zero), + Vector256.LoadUnsafe(ref dRef, lastVectorIndex), + TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex), + Vector256.LoadUnsafe(ref yRef, lastVectorIndex), + Vector256.LoadUnsafe(ref zRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex); } return; @@ -1060,9 +1108,12 @@ private static unsafe void InvokeSpanSpanSpanIntoSpan( if (i != x.Length) { uint lastVectorIndex = (uint)(x.Length - Vector128.Count); - TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex), - Vector128.LoadUnsafe(ref yRef, lastVectorIndex), - Vector128.LoadUnsafe(ref zRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); + Vector128.ConditionalSelect( + Vector128.Equals(LoadRemainderMaskSingleVector128(x.Length - i), Vector128.Zero), + Vector128.LoadUnsafe(ref dRef, lastVectorIndex), + TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex), + Vector128.LoadUnsafe(ref yRef, lastVectorIndex), + Vector128.LoadUnsafe(ref zRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex); } return; @@ -1121,9 +1172,12 @@ private static unsafe void InvokeSpanSpanScalarIntoSpan( if (i != x.Length) { uint lastVectorIndex = (uint)(x.Length - Vector512.Count); - TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex), - Vector512.LoadUnsafe(ref yRef, lastVectorIndex), - zVec).StoreUnsafe(ref dRef, lastVectorIndex); + Vector512.ConditionalSelect( + Vector512.Equals(LoadRemainderMaskSingleVector512(x.Length - i), Vector512.Zero), + Vector512.LoadUnsafe(ref dRef, lastVectorIndex), + TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex), + Vector512.LoadUnsafe(ref yRef, lastVectorIndex), + zVec)).StoreUnsafe(ref dRef, lastVectorIndex); } return; @@ -1153,9 +1207,12 @@ private static unsafe void InvokeSpanSpanScalarIntoSpan( if (i != x.Length) { uint lastVectorIndex = (uint)(x.Length - Vector256.Count); - TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex), - Vector256.LoadUnsafe(ref yRef, lastVectorIndex), - zVec).StoreUnsafe(ref dRef, lastVectorIndex); + Vector256.ConditionalSelect( + Vector256.Equals(LoadRemainderMaskSingleVector256(x.Length - i), Vector256.Zero), + Vector256.LoadUnsafe(ref dRef, lastVectorIndex), + TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex), + Vector256.LoadUnsafe(ref yRef, lastVectorIndex), + zVec)).StoreUnsafe(ref dRef, lastVectorIndex); } return; @@ -1184,9 +1241,12 @@ private static unsafe void InvokeSpanSpanScalarIntoSpan( if (i != x.Length) { uint lastVectorIndex = (uint)(x.Length - Vector128.Count); - TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex), - Vector128.LoadUnsafe(ref yRef, lastVectorIndex), - zVec).StoreUnsafe(ref dRef, lastVectorIndex); + Vector128.ConditionalSelect( + Vector128.Equals(LoadRemainderMaskSingleVector128(x.Length - i), Vector128.Zero), + Vector128.LoadUnsafe(ref dRef, lastVectorIndex), + TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex), + Vector128.LoadUnsafe(ref yRef, lastVectorIndex), + zVec)).StoreUnsafe(ref dRef, lastVectorIndex); } return; @@ -1245,9 +1305,12 @@ private static unsafe void InvokeSpanScalarSpanIntoSpan( if (i != x.Length) { uint lastVectorIndex = (uint)(x.Length - Vector512.Count); - TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex), - yVec, - Vector512.LoadUnsafe(ref zRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); + Vector512.ConditionalSelect( + Vector512.Equals(LoadRemainderMaskSingleVector512(x.Length - i), Vector512.Zero), + Vector512.LoadUnsafe(ref dRef, lastVectorIndex), + TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex), + yVec, + Vector512.LoadUnsafe(ref zRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex); } return; @@ -1277,9 +1340,12 @@ private static unsafe void InvokeSpanScalarSpanIntoSpan( if (i != x.Length) { uint lastVectorIndex = (uint)(x.Length - Vector256.Count); - TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex), - yVec, - Vector256.LoadUnsafe(ref zRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); + Vector256.ConditionalSelect( + Vector256.Equals(LoadRemainderMaskSingleVector256(x.Length - i), Vector256.Zero), + Vector256.LoadUnsafe(ref dRef, lastVectorIndex), + TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex), + yVec, + Vector256.LoadUnsafe(ref zRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex); } return; @@ -1308,9 +1374,12 @@ private static unsafe void InvokeSpanScalarSpanIntoSpan( if (i != x.Length) { uint lastVectorIndex = (uint)(x.Length - Vector128.Count); - TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex), - yVec, - Vector128.LoadUnsafe(ref zRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); + Vector128.ConditionalSelect( + Vector128.Equals(LoadRemainderMaskSingleVector128(x.Length - i), Vector128.Zero), + Vector128.LoadUnsafe(ref dRef, lastVectorIndex), + TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex), + yVec, + Vector128.LoadUnsafe(ref zRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex); } return; diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs index e05e54bcad769..70207a5c8995b 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs @@ -320,7 +320,11 @@ private static void InvokeSpanIntoSpan( if (i != x.Length) { int lastVectorIndex = x.Length - Vector.Count; - AsVector(ref dRef, lastVectorIndex) = op.Invoke(AsVector(ref xRef, lastVectorIndex)); + ref Vector dest = ref AsVector(ref dRef, lastVectorIndex); + dest = Vector.ConditionalSelect( + Vector.Equals(LoadRemainderMaskSingleVector(x.Length - i), Vector.Zero), + dest, + op.Invoke(AsVector(ref xRef, lastVectorIndex))); } return; @@ -374,8 +378,12 @@ private static void InvokeSpanSpanIntoSpan( if (i != x.Length) { int lastVectorIndex = x.Length - Vector.Count; - AsVector(ref dRef, lastVectorIndex) = op.Invoke(AsVector(ref xRef, lastVectorIndex), - AsVector(ref yRef, lastVectorIndex)); + ref Vector dest = ref AsVector(ref dRef, lastVectorIndex); + dest = Vector.ConditionalSelect( + Vector.Equals(LoadRemainderMaskSingleVector(x.Length - i), Vector.Zero), + dest, + op.Invoke(AsVector(ref xRef, lastVectorIndex), + AsVector(ref yRef, lastVectorIndex))); } return; @@ -424,8 +432,11 @@ private static void InvokeSpanScalarIntoSpan( if (i != x.Length) { int lastVectorIndex = x.Length - Vector.Count; - AsVector(ref dRef, lastVectorIndex) = op.Invoke(AsVector(ref xRef, lastVectorIndex), - yVec); + ref Vector dest = ref AsVector(ref dRef, lastVectorIndex); + dest = Vector.ConditionalSelect( + Vector.Equals(LoadRemainderMaskSingleVector(x.Length - i), Vector.Zero), + dest, + op.Invoke(AsVector(ref xRef, lastVectorIndex), yVec)); } return; @@ -482,9 +493,13 @@ private static void InvokeSpanSpanSpanIntoSpan( if (i != x.Length) { int lastVectorIndex = x.Length - Vector.Count; - AsVector(ref dRef, lastVectorIndex) = op.Invoke(AsVector(ref xRef, lastVectorIndex), - AsVector(ref yRef, lastVectorIndex), - AsVector(ref zRef, lastVectorIndex)); + ref Vector dest = ref AsVector(ref dRef, lastVectorIndex); + dest = Vector.ConditionalSelect( + Vector.Equals(LoadRemainderMaskSingleVector(x.Length - i), Vector.Zero), + dest, + op.Invoke(AsVector(ref xRef, lastVectorIndex), + AsVector(ref yRef, lastVectorIndex), + AsVector(ref zRef, lastVectorIndex))); } return; @@ -543,9 +558,13 @@ private static void InvokeSpanSpanScalarIntoSpan( if (i != x.Length) { int lastVectorIndex = x.Length - Vector.Count; - AsVector(ref dRef, lastVectorIndex) = op.Invoke(AsVector(ref xRef, lastVectorIndex), - AsVector(ref yRef, lastVectorIndex), - zVec); + ref Vector dest = ref AsVector(ref dRef, lastVectorIndex); + dest = Vector.ConditionalSelect( + Vector.Equals(LoadRemainderMaskSingleVector(x.Length - i), Vector.Zero), + dest, + op.Invoke(AsVector(ref xRef, lastVectorIndex), + AsVector(ref yRef, lastVectorIndex), + zVec)); } return; @@ -604,9 +623,13 @@ private static void InvokeSpanScalarSpanIntoSpan( if (i != x.Length) { int lastVectorIndex = x.Length - Vector.Count; - AsVector(ref dRef, lastVectorIndex) = op.Invoke(AsVector(ref xRef, lastVectorIndex), - yVec, - AsVector(ref zRef, lastVectorIndex)); + ref Vector dest = ref AsVector(ref dRef, lastVectorIndex); + dest = Vector.ConditionalSelect( + Vector.Equals(LoadRemainderMaskSingleVector(x.Length - i), Vector.Zero), + dest, + op.Invoke(AsVector(ref xRef, lastVectorIndex), + yVec, + AsVector(ref zRef, lastVectorIndex))); } return; diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs index edcebe8eb4775..751e352dd1da5 100644 --- a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs +++ b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs @@ -75,6 +75,21 @@ public static void Abs(int tensorLength) } } + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Abs_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.Abs(x, x); + + for (int i = 0; i < x.Length; i++) + { + Assert.Equal(MathF.Abs(xOrig[i]), x[i], Tolerance); + } + } + [Theory] [MemberData(nameof(TensorLengths))] public static void Abs_ThrowsForTooShortDestination(int tensorLength) @@ -96,11 +111,34 @@ public static void Add_TwoTensors(int tensorLength) using BoundedMemory destination = CreateTensor(tensorLength); TensorPrimitives.Add(x, y, destination); - for (int i = 0; i < tensorLength; i++) { Assert.Equal(x[i] + y[i], destination[i], Tolerance); } + + float[] xOrig = x.Span.ToArray(); + + // Validate that the destination can be the same as an input. + TensorPrimitives.Add(x, x, x); + for (int i = 0; i < tensorLength; i++) + { + Assert.Equal(xOrig[i] + xOrig[i], x[i], Tolerance); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Add_TwoTensors_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.Add(x, x, x); + + for (int i = 0; i < tensorLength; i++) + { + Assert.Equal(xOrig[i] + xOrig[i], x[i], Tolerance); + } } [Theory] @@ -142,6 +180,22 @@ public static void Add_TensorScalar(int tensorLength) } } + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Add_TensorScalar_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + float y = NextSingle(); + + TensorPrimitives.Add(x, y, x); + + for (int i = 0; i < tensorLength; i++) + { + Assert.Equal(xOrig[i] + y, x[i], Tolerance); + } + } + [Theory] [MemberData(nameof(TensorLengths))] public static void Add_TensorScalar_ThrowsForTooShortDestination(int tensorLength) @@ -172,6 +226,21 @@ public static void AddMultiply_ThreeTensors(int tensorLength) } } + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void AddMultiply_ThreeTensors_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.AddMultiply(x, x, x, x); + + for (int i = 0; i < tensorLength; i++) + { + Assert.Equal((xOrig[i] + xOrig[i]) * xOrig[i], x[i], Tolerance); + } + } + [Theory] [MemberData(nameof(TensorLengths))] public static void AddMultiply_ThreeTensors_ThrowsForMismatchedLengths(int tensorLength) @@ -215,6 +284,22 @@ public static void AddMultiply_TensorTensorScalar(int tensorLength) } } + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void AddMultiply_TensorTensorScalar_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + float multiplier = NextSingle(); + + TensorPrimitives.AddMultiply(x, x, multiplier, x); + + for (int i = 0; i < tensorLength; i++) + { + Assert.Equal((xOrig[i] + xOrig[i]) * multiplier, x[i], Tolerance); + } + } + [Theory] [MemberData(nameof(TensorLengths))] public static void AddMultiply_TensorTensorScalar_ThrowsForMismatchedLengths_x_y(int tensorLength) @@ -257,6 +342,22 @@ public static void AddMultiply_TensorScalarTensor(int tensorLength) } } + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void AddMultiply_TensorScalarTensor_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + float y = NextSingle(); + + TensorPrimitives.AddMultiply(x, y, x, x); + + for (int i = 0; i < tensorLength; i++) + { + Assert.Equal((xOrig[i] + y) * xOrig[i], x[i], Tolerance); + } + } + [Theory] [MemberData(nameof(TensorLengths))] public static void AddMultiply_TensorScalarTensor_ThrowsForMismatchedLengths_x_z(int tensorLength) @@ -299,6 +400,21 @@ public static void Cosh(int tensorLength) } } + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Cosh_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.Cosh(x, x); + + for (int i = 0; i < tensorLength; i++) + { + Assert.Equal(MathF.Cosh(xOrig[i]), x[i], Tolerance); + } + } + [Theory] [MemberData(nameof(TensorLengths))] public static void Cosh_ThrowsForTooShortDestination(int tensorLength) @@ -421,6 +537,21 @@ public static void Divide_TwoTensors(int tensorLength) } } + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Divide_TwoTensors_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.Divide(x, x, x); + + for (int i = 0; i < tensorLength; i++) + { + Assert.Equal(xOrig[i] / xOrig[i], x[i], Tolerance); + } + } + [Theory] [MemberData(nameof(TensorLengths))] public static void Divide_TwoTensors_ThrowsForMismatchedLengths(int tensorLength) @@ -460,6 +591,22 @@ public static void Divide_TensorScalar(int tensorLength) } } + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Divide_TensorScalar_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + float y = NextSingle(); + + TensorPrimitives.Divide(x, y, x); + + for (int i = 0; i < tensorLength; i++) + { + Assert.Equal(xOrig[i] / y, x[i], Tolerance); + } + } + [Theory] [MemberData(nameof(TensorLengths))] public static void Divide_TensorScalar_ThrowsForTooShortDestination(int tensorLength) @@ -527,6 +674,21 @@ public static void Exp(int tensorLength) } } + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Exp_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.Exp(x, x); + + for (int i = 0; i < tensorLength; i++) + { + Assert.Equal(MathF.Exp(xOrig[i]), x[i], Tolerance); + } + } + [Theory] [MemberData(nameof(TensorLengths))] public static void Exp_ThrowsForTooShortDestination(int tensorLength) @@ -735,6 +897,21 @@ public static void Log(int tensorLength) } } + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Log_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.Log(x, x); + + for (int i = 0; i < tensorLength; i++) + { + Assert.Equal(MathF.Log(xOrig[i]), x[i], Tolerance); + } + } + [Theory] [MemberData(nameof(TensorLengths))] public static void Log_ThrowsForTooShortDestination(int tensorLength) @@ -762,6 +939,21 @@ public static void Log2(int tensorLength) } } + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Log2_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.Log2(x, x); + + for (int i = 0; i < tensorLength; i++) + { + Assert.Equal(MathF.Log(xOrig[i], 2), x[i], Tolerance); + } + } + [Theory] [MemberData(nameof(TensorLengths))] public static void Log2_ThrowsForTooShortDestination(int tensorLength) @@ -834,6 +1026,32 @@ public static void Max_TwoTensors(int tensorLength) } } + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Max_TwoTensors_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(), yOrig = y.Span.ToArray(); + + TensorPrimitives.Max(x, y, x); + + for (int i = 0; i < tensorLength; i++) + { + Assert.Equal(MathF.Max(xOrig[i], y[i]), x[i], Tolerance); + } + + xOrig.AsSpan().CopyTo(x.Span); + yOrig.AsSpan().CopyTo(y.Span); + + TensorPrimitives.Max(x, y, y); + + for (int i = 0; i < tensorLength; i++) + { + Assert.Equal(MathF.Max(x[i], yOrig[i]), y[i], Tolerance); + } + } + [Theory] [MemberData(nameof(TensorLengths))] public static void Max_TwoTensors_SpecialValues(int tensorLength) @@ -955,6 +1173,32 @@ public static void MaxMagnitude_TwoTensors(int tensorLength) } } + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void MaxMagnitude_TwoTensors_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(), yOrig = y.Span.ToArray(); + + TensorPrimitives.MaxMagnitude(x, y, x); + + for (int i = 0; i < tensorLength; i++) + { + Assert.Equal(MathFMaxMagnitude(xOrig[i], y[i]), x[i], Tolerance); + } + + xOrig.AsSpan().CopyTo(x.Span); + yOrig.AsSpan().CopyTo(y.Span); + + TensorPrimitives.MaxMagnitude(x, y, y); + + for (int i = 0; i < tensorLength; i++) + { + Assert.Equal(MathFMaxMagnitude(x[i], yOrig[i]), y[i], Tolerance); + } + } + [Theory] [MemberData(nameof(TensorLengths))] public static void MaxMagnitude_TwoTensors_SpecialValues(int tensorLength) @@ -1075,6 +1319,32 @@ public static void Min_TwoTensors(int tensorLength) } } + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Min_TwoTensors_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(), yOrig = y.Span.ToArray(); + + TensorPrimitives.Min(x, y, x); + + for (int i = 0; i < tensorLength; i++) + { + Assert.Equal(MathF.Min(xOrig[i], y[i]), x[i], Tolerance); + } + + xOrig.AsSpan().CopyTo(x.Span); + yOrig.AsSpan().CopyTo(y.Span); + + TensorPrimitives.Min(x, y, y); + + for (int i = 0; i < tensorLength; i++) + { + Assert.Equal(MathF.Min(x[i], yOrig[i]), y[i], Tolerance); + } + } + [Theory] [MemberData(nameof(TensorLengths))] public static void Min_TwoTensors_SpecialValues(int tensorLength) @@ -1194,6 +1464,32 @@ public static void MinMagnitude_TwoTensors(int tensorLength) } } + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void MinMagnitude_TwoTensors_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(), yOrig = y.Span.ToArray(); + + TensorPrimitives.MinMagnitude(x, y, x); + + for (int i = 0; i < tensorLength; i++) + { + Assert.Equal(MathFMinMagnitude(xOrig[i], y[i]), x[i], Tolerance); + } + + xOrig.AsSpan().CopyTo(x.Span); + yOrig.AsSpan().CopyTo(y.Span); + + TensorPrimitives.MinMagnitude(x, y, y); + + for (int i = 0; i < tensorLength; i++) + { + Assert.Equal(MathFMinMagnitude(x[i], yOrig[i]), y[i], Tolerance); + } + } + [Theory] [MemberData(nameof(TensorLengths))] public static void MinMagnitude_TwoTensors_SpecialValues(int tensorLength) @@ -1270,6 +1566,21 @@ public static void Multiply_TwoTensors(int tensorLength) } } + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Multiply_TwoTensors_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.Multiply(x, x, x); + + for (int i = 0; i < tensorLength; i++) + { + Assert.Equal(xOrig[i] * xOrig[i], x[i], Tolerance); + } + } + [Theory] [MemberData(nameof(TensorLengths))] public static void Multiply_TwoTensors_ThrowsForMismatchedLengths(int tensorLength) @@ -1309,6 +1620,22 @@ public static void Multiply_TensorScalar(int tensorLength) } } + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Multiply_TensorScalar_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + float y = NextSingle(); + + TensorPrimitives.Multiply(x, y, x); + + for (int i = 0; i < tensorLength; i++) + { + Assert.Equal(xOrig[i] * y, x[i], Tolerance); + } + } + [Theory] [MemberData(nameof(TensorLengths))] public static void Multiply_TensorScalar_ThrowsForTooShortDestination(int tensorLength) @@ -1339,6 +1666,21 @@ public static void MultiplyAdd_ThreeTensors(int tensorLength) } } + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void MultiplyAdd_ThreeTensors_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.MultiplyAdd(x, x, x, x); + + for (int i = 0; i < tensorLength; i++) + { + Assert.Equal((xOrig[i] * xOrig[i]) + xOrig[i], x[i], Tolerance); + } + } + [Theory] [MemberData(nameof(TensorLengths))] public static void MultiplyAdd_ThreeTensors_ThrowsForMismatchedLengths_x_y(int tensorLength) @@ -1382,6 +1724,22 @@ public static void MultiplyAdd_TensorTensorScalar(int tensorLength) } } + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void MultiplyAdd_TensorTensorScalar_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + float addend = NextSingle(); + + TensorPrimitives.MultiplyAdd(x, x, addend, x); + + for (int i = 0; i < tensorLength; i++) + { + Assert.Equal((xOrig[i] * xOrig[i]) + addend, x[i], Tolerance); + } + } + [Theory] [MemberData(nameof(TensorLengths))] public static void MultiplyAdd_TensorTensorScalar_ThrowsForTooShortDestination(int tensorLength) @@ -1411,6 +1769,22 @@ public static void MultiplyAdd_TensorScalarTensor(int tensorLength) } } + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void MultiplyAdd_TensorScalarTensor_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + float y = NextSingle(); + + TensorPrimitives.MultiplyAdd(x, y, x, x); + + for (int i = 0; i < tensorLength; i++) + { + Assert.Equal((xOrig[i] * y) + xOrig[i], x[i], Tolerance); + } + } + [Theory] [MemberData(nameof(TensorLengths))] public static void MultiplyAdd_TensorScalarTensor_ThrowsForTooShortDestination(int tensorLength) @@ -1440,6 +1814,21 @@ public static void Negate(int tensorLength) } } + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Negate_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.Negate(x, x); + + for (int i = 0; i < tensorLength; i++) + { + Assert.Equal(-xOrig[i], x[i], Tolerance); + } + } + [Theory] [MemberData(nameof(TensorLengths))] public static void Negate_ThrowsForTooShortDestination(int tensorLength) @@ -1598,6 +1987,36 @@ public static void ProductOfSums_KnownValues() #endregion #region Sigmoid + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Sigmoid(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.Sigmoid(x, destination); + + for (int i = 0; i < tensorLength; i++) + { + Assert.Equal(1f / (1f + MathF.Exp(-x[i])), destination[i], Tolerance); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Sigmoid_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.Sigmoid(x, x); + + for (int i = 0; i < tensorLength; i++) + { + Assert.Equal(1f / (1f + MathF.Exp(-xOrig[i])), x[i], Tolerance); + } + } + [Theory] [MemberData(nameof(TensorLengths))] public static void Sigmoid_ThrowsForTooShortDestination(int tensorLength) @@ -1612,7 +2031,7 @@ public static void Sigmoid_ThrowsForTooShortDestination(int tensorLength) [InlineData(new float[] { -5, -4.5f, -4 }, new float[] { 0.0066f, 0.0109f, 0.0179f })] [InlineData(new float[] { 4.5f, 5 }, new float[] { 0.9890f, 0.9933f })] [InlineData(new float[] { 0, -3, 3, .5f }, new float[] { 0.5f, 0.0474f, 0.9525f, 0.6224f })] - public static void Sigmoid(float[] x, float[] expectedResult) + public static void Sigmoid_KnownValues(float[] x, float[] expectedResult) { using BoundedMemory dest = CreateTensor(x.Length); TensorPrimitives.Sigmoid(x, dest); @@ -1663,6 +2082,21 @@ public static void Sinh(int tensorLength) } } + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Sinh_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.Sinh(x, x); + + for (int i = 0; i < tensorLength; i++) + { + Assert.Equal(MathF.Sinh(xOrig[i]), x[i], Tolerance); + } + } + [Theory] [MemberData(nameof(TensorLengths))] public static void Sinh_ThrowsForTooShortDestination(int tensorLength) @@ -1675,6 +2109,38 @@ public static void Sinh_ThrowsForTooShortDestination(int tensorLength) #endregion #region SoftMax + [Theory] + [MemberData(nameof(TensorLengths))] + public static void SoftMax(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.SoftMax(x, destination); + + float expSum = MemoryMarshal.ToEnumerable(x.Memory).Sum(MathF.Exp); + for (int i = 0; i < tensorLength; i++) + { + Assert.Equal(MathF.Exp(x[i]) / expSum, destination[i], Tolerance); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void SoftMax_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.SoftMax(x, x); + + float expSum = xOrig.Sum(MathF.Exp); + for (int i = 0; i < tensorLength; i++) + { + Assert.Equal(MathF.Exp(xOrig[i]) / expSum, x[i], Tolerance); + } + } + [Theory] [MemberData(nameof(TensorLengths))] public static void SoftMax_ThrowsForTooShortDestination(int tensorLength) @@ -1690,7 +2156,7 @@ public static void SoftMax_ThrowsForTooShortDestination(int tensorLength) [InlineData(new float[] { 3, 4, 1 }, new float[] { 0.2594f, 0.705384f, 0.0351f })] [InlineData(new float[] { 5, 3 }, new float[] { 0.8807f, 0.1192f })] [InlineData(new float[] { 4, 2, 1, 9 }, new float[] { 0.0066f, 9.04658e-4f, 3.32805e-4f, 0.9920f })] - public static void SoftMax(float[] x, float[] expectedResult) + public static void SoftMax_KnownValues(float[] x, float[] expectedResult) { using BoundedMemory dest = CreateTensor(x.Length); TensorPrimitives.SoftMax(x, dest); @@ -1739,6 +2205,21 @@ public static void Subtract_TwoTensors(int tensorLength) } } + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Subtract_TwoTensors_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.Subtract(x, x, x); + + for (int i = 0; i < tensorLength; i++) + { + Assert.Equal(xOrig[i] - xOrig[i], x[i], Tolerance); + } + } + [Theory] [MemberData(nameof(TensorLengths))] public static void Subtract_TwoTensors_ThrowsForMismatchedLengths(int tensorLength) @@ -1778,6 +2259,22 @@ public static void Subtract_TensorScalar(int tensorLength) } } + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Subtract_TensorScalar_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + float y = NextSingle(); + + TensorPrimitives.Subtract(x, y, x); + + for (int i = 0; i < tensorLength; i++) + { + Assert.Equal(xOrig[i] - y, x[i], Tolerance); + } + } + [Theory] [MemberData(nameof(TensorLengths))] public static void Subtract_TensorScalar_ThrowsForTooShortDestination(int tensorLength) @@ -1797,7 +2294,7 @@ public static void Sum(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - Assert.Equal(Enumerable.Sum(MemoryMarshal.ToEnumerable(x.Memory)), TensorPrimitives.Sum(x), Tolerance); + Assert.Equal(MemoryMarshal.ToEnumerable(x.Memory).Sum(), TensorPrimitives.Sum(x), Tolerance); float sum = 0; foreach (float f in x.Span) @@ -1890,6 +2387,21 @@ public static void Tanh(int tensorLength) } } + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Tanh_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.Tanh(x, x); + + for (int i = 0; i < tensorLength; i++) + { + Assert.Equal(MathF.Tanh(xOrig[i]), x[i], Tolerance); + } + } + [Theory] [MemberData(nameof(TensorLengths))] public static void Tanh_ThrowsForTooShortDestination(int tensorLength)