diff --git a/src/libraries/System.Numerics.Tensors/System.Numerics.Tensors.sln b/src/libraries/System.Numerics.Tensors/System.Numerics.Tensors.sln index cc3000d60ef88..015b65250931a 100644 --- a/src/libraries/System.Numerics.Tensors/System.Numerics.Tensors.sln +++ b/src/libraries/System.Numerics.Tensors/System.Numerics.Tensors.sln @@ -1,18 +1,34 @@ -Microsoft Visual Studio Solution File, Format Version 12.00 + +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio Version 17 +VisualStudioVersion = 17.8.34205.153 +MinimumVisualStudioVersion = 10.0.40219.1 Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TestUtilities", "..\Common\tests\TestUtilities\TestUtilities.csproj", "{9F20CEA1-2216-4432-BBBD-F01E05D17F23}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Bcl.Numerics", "..\Microsoft.Bcl.Numerics\ref\Microsoft.Bcl.Numerics.csproj", "{D311ABE4-10A9-4BB1-89CE-6358C55501A8}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Bcl.Numerics", "..\Microsoft.Bcl.Numerics\src\Microsoft.Bcl.Numerics.csproj", "{1578185F-C4FA-4866-936B-E62AAEDD03B7}" +EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "System.Numerics.Tensors", "ref\System.Numerics.Tensors.csproj", "{21CB448A-3882-4337-B416-D1A3E0BCFFC5}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "System.Numerics.Tensors", "src\System.Numerics.Tensors.csproj", "{848DD000-3D22-4A25-A9D9-05AFF857A116}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "System.Numerics.Tensors.Tests", "tests\System.Numerics.Tensors.Tests.csproj", "{4AF6A02D-82C8-4898-9EDF-01F107C25061}" EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "ComInterfaceGenerator", "..\System.Runtime.InteropServices\gen\ComInterfaceGenerator\ComInterfaceGenerator.csproj", "{8CA7C982-3EE4-4BCE-9493-7A63556736D3}" -EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LibraryImportGenerator", "..\System.Runtime.InteropServices\gen\LibraryImportGenerator\LibraryImportGenerator.csproj", "{4588351F-4233-4957-B84C-7F8E22B8888A}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Interop.SourceGeneration", "..\System.Runtime.InteropServices\gen\Microsoft.Interop.SourceGeneration\Microsoft.Interop.SourceGeneration.csproj", "{DB954E01-898A-4FE2-A3AA-180D041AB08F}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "ILLink.CodeFixProvider", "..\..\tools\illink\src\ILLink.CodeFix\ILLink.CodeFixProvider.csproj", "{04FC0651-B9D0-448A-A28B-11B1D4A897F4}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "ILLink.RoslynAnalyzer", "..\..\tools\illink\src\ILLink.RoslynAnalyzer\ILLink.RoslynAnalyzer.csproj", "{683A7D28-CC55-4375-848D-E659075ECEE4}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "ILLink.Tasks", "..\..\tools\illink\src\ILLink.Tasks\ILLink.Tasks.csproj", "{1CBEAEA8-2CA1-4B07-9930-35A785205852}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Mono.Linker", "..\..\tools\illink\src\linker\Mono.Linker.csproj", "{BA7828B1-7953-47A0-AE5A-E22B501C4BD0}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Mono.Linker", "..\..\tools\illink\src\linker\ref\Mono.Linker.csproj", "{57E57290-3A6A-43F8-8764-D4DC8151F89C}" +EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "tests", "tests", "{DE94CA7D-BB10-4865-85A6-6B694631247F}" EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "ref", "ref", "{6BC42E6D-848C-4533-B715-F116E7DB3610}" @@ 
-21,6 +37,14 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{AB415F5A-75E EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "gen", "gen", "{083161E5-6049-4D84-9739-9D7797D7117D}" EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "gen", "gen", "{841A2FA4-A95F-4612-A8B9-AD2EF769BC71}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{DF0561A1-3AB8-4B51-AFB4-392EE1DD6247}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "ref", "ref", "{7AC4B2C7-A55C-4C4F-9B02-77F5CBFFF4AB}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "tools", "tools", "{F9C2AAB1-C7B0-4E43-BB18-4FB16F6E272B}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -31,6 +55,14 @@ Global {9F20CEA1-2216-4432-BBBD-F01E05D17F23}.Debug|Any CPU.Build.0 = Debug|Any CPU {9F20CEA1-2216-4432-BBBD-F01E05D17F23}.Release|Any CPU.ActiveCfg = Release|Any CPU {9F20CEA1-2216-4432-BBBD-F01E05D17F23}.Release|Any CPU.Build.0 = Release|Any CPU + {D311ABE4-10A9-4BB1-89CE-6358C55501A8}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {D311ABE4-10A9-4BB1-89CE-6358C55501A8}.Debug|Any CPU.Build.0 = Debug|Any CPU + {D311ABE4-10A9-4BB1-89CE-6358C55501A8}.Release|Any CPU.ActiveCfg = Release|Any CPU + {D311ABE4-10A9-4BB1-89CE-6358C55501A8}.Release|Any CPU.Build.0 = Release|Any CPU + {1578185F-C4FA-4866-936B-E62AAEDD03B7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {1578185F-C4FA-4866-936B-E62AAEDD03B7}.Debug|Any CPU.Build.0 = Debug|Any CPU + {1578185F-C4FA-4866-936B-E62AAEDD03B7}.Release|Any CPU.ActiveCfg = Release|Any CPU + {1578185F-C4FA-4866-936B-E62AAEDD03B7}.Release|Any CPU.Build.0 = Release|Any CPU {21CB448A-3882-4337-B416-D1A3E0BCFFC5}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {21CB448A-3882-4337-B416-D1A3E0BCFFC5}.Debug|Any CPU.Build.0 = Debug|Any CPU {21CB448A-3882-4337-B416-D1A3E0BCFFC5}.Release|Any CPU.ActiveCfg = Release|Any CPU @@ -43,10 +75,6 @@ Global {4AF6A02D-82C8-4898-9EDF-01F107C25061}.Debug|Any CPU.Build.0 = Debug|Any CPU {4AF6A02D-82C8-4898-9EDF-01F107C25061}.Release|Any CPU.ActiveCfg = Release|Any CPU {4AF6A02D-82C8-4898-9EDF-01F107C25061}.Release|Any CPU.Build.0 = Release|Any CPU - {8CA7C982-3EE4-4BCE-9493-7A63556736D3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {8CA7C982-3EE4-4BCE-9493-7A63556736D3}.Debug|Any CPU.Build.0 = Debug|Any CPU - {8CA7C982-3EE4-4BCE-9493-7A63556736D3}.Release|Any CPU.ActiveCfg = Release|Any CPU - {8CA7C982-3EE4-4BCE-9493-7A63556736D3}.Release|Any CPU.Build.0 = Release|Any CPU {4588351F-4233-4957-B84C-7F8E22B8888A}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {4588351F-4233-4957-B84C-7F8E22B8888A}.Debug|Any CPU.Build.0 = Debug|Any CPU {4588351F-4233-4957-B84C-7F8E22B8888A}.Release|Any CPU.ActiveCfg = Release|Any CPU @@ -55,20 +83,53 @@ Global {DB954E01-898A-4FE2-A3AA-180D041AB08F}.Debug|Any CPU.Build.0 = Debug|Any CPU {DB954E01-898A-4FE2-A3AA-180D041AB08F}.Release|Any CPU.ActiveCfg = Release|Any CPU {DB954E01-898A-4FE2-A3AA-180D041AB08F}.Release|Any CPU.Build.0 = Release|Any CPU + {04FC0651-B9D0-448A-A28B-11B1D4A897F4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {04FC0651-B9D0-448A-A28B-11B1D4A897F4}.Debug|Any CPU.Build.0 = Debug|Any CPU + {04FC0651-B9D0-448A-A28B-11B1D4A897F4}.Release|Any CPU.ActiveCfg = Release|Any CPU + {04FC0651-B9D0-448A-A28B-11B1D4A897F4}.Release|Any CPU.Build.0 = Release|Any CPU + {683A7D28-CC55-4375-848D-E659075ECEE4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {683A7D28-CC55-4375-848D-E659075ECEE4}.Debug|Any CPU.Build.0 = Debug|Any 
CPU + {683A7D28-CC55-4375-848D-E659075ECEE4}.Release|Any CPU.ActiveCfg = Release|Any CPU + {683A7D28-CC55-4375-848D-E659075ECEE4}.Release|Any CPU.Build.0 = Release|Any CPU + {1CBEAEA8-2CA1-4B07-9930-35A785205852}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {1CBEAEA8-2CA1-4B07-9930-35A785205852}.Debug|Any CPU.Build.0 = Debug|Any CPU + {1CBEAEA8-2CA1-4B07-9930-35A785205852}.Release|Any CPU.ActiveCfg = Release|Any CPU + {1CBEAEA8-2CA1-4B07-9930-35A785205852}.Release|Any CPU.Build.0 = Release|Any CPU + {BA7828B1-7953-47A0-AE5A-E22B501C4BD0}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {BA7828B1-7953-47A0-AE5A-E22B501C4BD0}.Debug|Any CPU.Build.0 = Debug|Any CPU + {BA7828B1-7953-47A0-AE5A-E22B501C4BD0}.Release|Any CPU.ActiveCfg = Release|Any CPU + {BA7828B1-7953-47A0-AE5A-E22B501C4BD0}.Release|Any CPU.Build.0 = Release|Any CPU + {57E57290-3A6A-43F8-8764-D4DC8151F89C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {57E57290-3A6A-43F8-8764-D4DC8151F89C}.Debug|Any CPU.Build.0 = Debug|Any CPU + {57E57290-3A6A-43F8-8764-D4DC8151F89C}.Release|Any CPU.ActiveCfg = Release|Any CPU + {57E57290-3A6A-43F8-8764-D4DC8151F89C}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE EndGlobalSection GlobalSection(NestedProjects) = preSolution {9F20CEA1-2216-4432-BBBD-F01E05D17F23} = {DE94CA7D-BB10-4865-85A6-6B694631247F} - {4AF6A02D-82C8-4898-9EDF-01F107C25061} = {DE94CA7D-BB10-4865-85A6-6B694631247F} + {D311ABE4-10A9-4BB1-89CE-6358C55501A8} = {6BC42E6D-848C-4533-B715-F116E7DB3610} + {1578185F-C4FA-4866-936B-E62AAEDD03B7} = {AB415F5A-75E5-4E03-8A92-15CEDEC4CD3A} {21CB448A-3882-4337-B416-D1A3E0BCFFC5} = {6BC42E6D-848C-4533-B715-F116E7DB3610} {848DD000-3D22-4A25-A9D9-05AFF857A116} = {AB415F5A-75E5-4E03-8A92-15CEDEC4CD3A} - {8CA7C982-3EE4-4BCE-9493-7A63556736D3} = {083161E5-6049-4D84-9739-9D7797D7117D} + {4AF6A02D-82C8-4898-9EDF-01F107C25061} = {DE94CA7D-BB10-4865-85A6-6B694631247F} {4588351F-4233-4957-B84C-7F8E22B8888A} = {083161E5-6049-4D84-9739-9D7797D7117D} {DB954E01-898A-4FE2-A3AA-180D041AB08F} = {083161E5-6049-4D84-9739-9D7797D7117D} + {04FC0651-B9D0-448A-A28B-11B1D4A897F4} = {841A2FA4-A95F-4612-A8B9-AD2EF769BC71} + {683A7D28-CC55-4375-848D-E659075ECEE4} = {841A2FA4-A95F-4612-A8B9-AD2EF769BC71} + {1CBEAEA8-2CA1-4B07-9930-35A785205852} = {DF0561A1-3AB8-4B51-AFB4-392EE1DD6247} + {BA7828B1-7953-47A0-AE5A-E22B501C4BD0} = {DF0561A1-3AB8-4B51-AFB4-392EE1DD6247} + {57E57290-3A6A-43F8-8764-D4DC8151F89C} = {7AC4B2C7-A55C-4C4F-9B02-77F5CBFFF4AB} + {841A2FA4-A95F-4612-A8B9-AD2EF769BC71} = {F9C2AAB1-C7B0-4E43-BB18-4FB16F6E272B} + {DF0561A1-3AB8-4B51-AFB4-392EE1DD6247} = {F9C2AAB1-C7B0-4E43-BB18-4FB16F6E272B} + {7AC4B2C7-A55C-4C4F-9B02-77F5CBFFF4AB} = {F9C2AAB1-C7B0-4E43-BB18-4FB16F6E272B} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {10A5F2C3-5230-4916-9D4D-BBDB94851037} EndGlobalSection -EndGlobal + GlobalSection(SharedMSBuildProjectFiles) = preSolution + ..\..\tools\illink\src\ILLink.Shared\ILLink.Shared.projitems*{683a7d28-cc55-4375-848d-e659075ecee4}*SharedItemsImports = 5 + ..\..\tools\illink\src\ILLink.Shared\ILLink.Shared.projitems*{ba7828b1-7953-47a0-ae5a-e22b501c4bd0}*SharedItemsImports = 5 + EndGlobalSection +EndGlobal \ No newline at end of file diff --git a/src/libraries/System.Numerics.Tensors/src/ReferenceAssemblyExclusions.txt b/src/libraries/System.Numerics.Tensors/src/ReferenceAssemblyExclusions.txt new file mode 100644 index 0000000000000..a8f2d0192cfec --- /dev/null +++ 
b/src/libraries/System.Numerics.Tensors/src/ReferenceAssemblyExclusions.txt @@ -0,0 +1,2 @@ +M:System.Numerics.Tensors.TensorPrimitives.ConvertToHalf(System.ReadOnlySpan{System.Single},System.Span{System.Half}) +M:System.Numerics.Tensors.TensorPrimitives.ConvertToSingle(System.ReadOnlySpan{System.Half},System.Span{System.Single}) \ No newline at end of file diff --git a/src/libraries/System.Numerics.Tensors/src/Resources/Strings.resx b/src/libraries/System.Numerics.Tensors/src/Resources/Strings.resx index 45f0d8fa17893..86b9f4d82b1f6 100644 --- a/src/libraries/System.Numerics.Tensors/src/Resources/Strings.resx +++ b/src/libraries/System.Numerics.Tensors/src/Resources/Strings.resx @@ -126,4 +126,7 @@ Input span arguments must all have the same length. - \ No newline at end of file + + The destination span may only overlap with an input span if the two spans start at the same memory location. + + diff --git a/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj b/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj index be4a04702af5e..52c6cb65811e6 100644 --- a/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj +++ b/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj @@ -9,6 +9,7 @@ Once this package has shipped a stable version, the following line should be removed in order to re-enable validation. --> true + ReferenceAssemblyExclusions.txt diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs index d28d4bacafdb8..03db1abb7f858 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs @@ -1,953 +1,1097 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + namespace System.Numerics.Tensors { /// Performs primitive tensor operations over spans of memory. public static partial class TensorPrimitives { - /// Computes the element-wise result of: + . + /// Computes the element-wise absolute value of each single-precision floating-point number in the specified tensor. + /// The tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = MathF.Abs([i]). + /// + /// + /// The absolute value of a is its numeric value without its sign. For example, the absolute value of both 1.2e-03 and -1.2e03 is 1.2e03. + /// + /// + /// If a value is equal to or , the result stored into the corresponding destination location is set to . + /// If a value is equal to , the result stored into the corresponding destination location is the original NaN value with the sign bit removed. + /// + /// + public static void Abs(ReadOnlySpan x, Span destination) => + InvokeSpanIntoSpan(x, destination); + + /// Computes the element-wise addition of single-precision floating-point numbers in the specified tensors. /// The first tensor, represented as a span. /// The second tensor, represented as a span. /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. 
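
As a quick, checkable illustration of the Abs and Add overloads above, here is a minimal top-level C# sketch (it assumes a project referencing the System.Numerics.Tensors package; the values and expected results in comments are illustrative only):

using System;
using System.Numerics.Tensors;

float[] x = { -1.5f, 2f, -3f };
float[] y = { 4f, -0.5f, 6f };
float[] destination = new float[x.Length];

TensorPrimitives.Add(x, y, destination);        // expected: 2.5, 1.5, 3
TensorPrimitives.Abs(destination, destination); // fully overlapping spans are allowed because they start at the same location
Console.WriteLine(string.Join(", ", destination));
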
+ /// Length of must be same as length of . /// Destination is too short. - /// This method effectively does [i] = [i] + [i]. - public static unsafe void Add(ReadOnlySpan x, ReadOnlySpan y, Span destination) => + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = [i] + [i]. + /// + /// + /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void Add(ReadOnlySpan x, ReadOnlySpan y, Span destination) => InvokeSpanSpanIntoSpan(x, y, destination); - /// Computes the element-wise result of: + . + /// Computes the element-wise addition of single-precision floating-point numbers in the specified tensors. /// The first tensor, represented as a span. /// The second tensor, represented as a scalar. /// The destination tensor, represented as a span. /// Destination is too short. - /// This method effectively does [i] = [i] + . + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = [i] + . + /// + /// + /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// public static void Add(ReadOnlySpan x, float y, Span destination) => InvokeSpanScalarIntoSpan(x, y, destination); - /// Computes the element-wise result of: - . - /// The first tensor, represented as a span. - /// The second tensor, represented as a scalar. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// This method effectively does [i] = [i] - [i]. - public static void Subtract(ReadOnlySpan x, ReadOnlySpan y, Span destination) => - InvokeSpanSpanIntoSpan(x, y, destination); - - /// Computes the element-wise result of: - . - /// The first tensor, represented as a span. - /// The second tensor, represented as a scalar. - /// The destination tensor, represented as a span. - /// Destination is too short. - /// This method effectively does [i] = [i] - . - public static void Subtract(ReadOnlySpan x, float y, Span destination) => - InvokeSpanScalarIntoSpan(x, y, destination); - - /// Computes the element-wise result of: * . + /// Computes the element-wise result of ( + ) * for the specified tensors. /// The first tensor, represented as a span. /// The second tensor, represented as a span. + /// The third tensor, represented as a span. /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// This method effectively does [i] = [i] * . - public static void Multiply(ReadOnlySpan x, ReadOnlySpan y, Span destination) => - InvokeSpanSpanIntoSpan(x, y, destination); - - /// Computes the element-wise result of: * . - /// The first tensor, represented as a span. - /// The second tensor, represented as a scalar. - /// The destination tensor, represented as a span. + /// Length of must be same as length of and the length of . /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. /// - /// This method effectively does [i] = [i] * . 
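
A minimal sketch of the three-span AddMultiply overload documented above, which fuses an Add and a Multiply into a single pass (illustrative values; assumes a System.Numerics.Tensors reference):

using System;
using System.Numerics.Tensors;

float[] x = { 1f, 2f };
float[] y = { 3f, 4f };
float[] multiplier = { 2f, 10f };
float[] destination = new float[x.Length];

// destination[i] = (x[i] + y[i]) * multiplier[i]
TensorPrimitives.AddMultiply(x, y, multiplier, destination);
Console.WriteLine(string.Join(", ", destination)); // expected: 8, 60
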
- /// This method corresponds to the scal method defined by BLAS1. + /// + /// This method effectively computes [i] = ([i] + [i]) * [i]. + /// + /// + /// If any of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// /// - public static void Multiply(ReadOnlySpan x, float y, Span destination) => - InvokeSpanScalarIntoSpan(x, y, destination); + public static void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan multiplier, Span destination) => + InvokeSpanSpanSpanIntoSpan(x, y, multiplier, destination); - /// Computes the element-wise result of: / . + /// Computes the element-wise result of ( + ) * for the specified tensors. /// The first tensor, represented as a span. /// The second tensor, represented as a span. + /// The third tensor, represented as a scalar. /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. + /// Length of must be same as length of . /// Destination is too short. - /// This method effectively does [i] = [i] / . - public static void Divide(ReadOnlySpan x, ReadOnlySpan y, Span destination) => - InvokeSpanSpanIntoSpan(x, y, destination); + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = ([i] + [i]) * . + /// + /// + /// If any of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, float multiplier, Span destination) => + InvokeSpanSpanScalarIntoSpan(x, y, multiplier, destination); - /// Computes the element-wise result of: / . + /// Computes the element-wise result of ( + ) * for the specified tensors. /// The first tensor, represented as a span. /// The second tensor, represented as a scalar. + /// The third tensor, represented as a span. /// The destination tensor, represented as a span. + /// Length of must be same as length of . /// Destination is too short. - /// This method effectively does [i] = [i] / . - public static void Divide(ReadOnlySpan x, float y, Span destination) => - InvokeSpanScalarIntoSpan(x, y, destination); - - /// Computes the element-wise result of: -. - /// The tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Destination is too short. - /// This method effectively does [i] = -[i]. - public static void Negate(ReadOnlySpan x, Span destination) => - InvokeSpanIntoSpan(x, destination); + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = ([i] + ) * [i]. + /// + /// + /// If any of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void AddMultiply(ReadOnlySpan x, float y, ReadOnlySpan multiplier, Span destination) => + InvokeSpanScalarSpanIntoSpan(x, y, multiplier, destination); - /// Computes the element-wise result of: MathF.Abs(). + /// Computes the element-wise hyperbolic cosine of each single-precision floating-point radian angle in the specified tensor. /// The tensor, represented as a span. /// The destination tensor, represented as a span. /// Destination is too short. - /// This method effectively does [i] = MathF.Abs([i]). 
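
The scalar-multiplier AddMultiply overload is convenient for expressions such as an element-wise average; a minimal sketch (illustrative values):

using System;
using System.Numerics.Tensors;

float[] x = { 1f, 2f, 3f };
float[] y = { 4f, 5f, 6f };
float[] destination = new float[x.Length];

// destination[i] = (x[i] + y[i]) * 0.5f — the element-wise mean of the two tensors.
TensorPrimitives.AddMultiply(x, y, 0.5f, destination);
Console.WriteLine(string.Join(", ", destination)); // expected: 2.5, 3.5, 4.5
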
- public static void Abs(ReadOnlySpan x, Span destination) => - InvokeSpanIntoSpan(x, destination); + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = .Cosh([i]). + /// + /// + /// If a value is equal to or , the result stored into the corresponding destination location is set to . + /// If a value is equal to , the result stored into the corresponding destination location is also NaN. + /// + /// + /// The angles in x must be in radians. Use or multiply by /180 to convert degrees to radians. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void Cosh(ReadOnlySpan x, Span destination) => + InvokeSpanIntoSpan(x, destination); - /// Computes the element-wise result of: ( + ) * . + /// Computes the cosine similarity between the two specified non-empty, equal-length tensors of single-precision floating-point numbers. /// The first tensor, represented as a span. /// The second tensor, represented as a span. - /// The third tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// This method effectively does [i] = ([i] + [i]) * [i]. - public static void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan multiplier, Span destination) => - InvokeSpanSpanSpanIntoSpan(x, y, multiplier, destination); + /// The cosine similarity of the two tensors. + /// Length of must be same as length of . + /// and must not be empty. + /// + /// + /// This method effectively computes TensorPrimitives.Dot(x, y) / (MathF.Sqrt(TensorPrimitives.SumOfSquares(x)) * MathF.Sqrt(TensorPrimitives.SumOfSquares(y)). + /// + /// + /// If any element in either input tensor is equal to , , or , + /// NaN is returned. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static float CosineSimilarity(ReadOnlySpan x, ReadOnlySpan y) + { + if (x.IsEmpty) + { + ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); + } + + if (x.Length != y.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + return CosineSimilarityCore(x, y); + } - /// Computes the element-wise result of: ( + ) * . + /// Computes the distance between two points, specified as non-empty, equal-length tensors of single-precision floating-point numbers, in Euclidean space. /// The first tensor, represented as a span. /// The second tensor, represented as a span. - /// The third tensor, represented as a scalar. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// This method effectively does [i] = ([i] + [i]) * . - public static void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, float multiplier, Span destination) => - InvokeSpanSpanScalarIntoSpan(x, y, multiplier, destination); + /// The Euclidean distance. + /// Length of must be same as length of . + /// and must not be empty. 
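
The equivalence stated in the CosineSimilarity remarks can be checked directly; a sketch with illustrative values:

using System;
using System.Numerics.Tensors;

float[] x = { 1f, 0f };
float[] y = { 1f, 1f };

float similarity = TensorPrimitives.CosineSimilarity(x, y);
float manual = TensorPrimitives.Dot(x, y) /
    (MathF.Sqrt(TensorPrimitives.SumOfSquares(x)) * MathF.Sqrt(TensorPrimitives.SumOfSquares(y)));

Console.WriteLine(similarity); // expected: ~0.7071068 for this pair
Console.WriteLine(manual);     // same value, up to floating-point rounding
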
+ /// + /// + /// This method effectively computes the equivalent of: + /// + /// Span<float> difference = ...; + /// TensorPrimitives.Subtract(x, y, difference); + /// float result = MathF.Sqrt(TensorPrimitives.SumOfSquares(difference)); + /// + /// but without requiring additional temporary storage for the intermediate differences. + /// + /// + /// If any element in either input tensor is equal to , NaN is returned. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static float Distance(ReadOnlySpan x, ReadOnlySpan y) + { + if (x.IsEmpty) + { + ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); + } - /// Computes the element-wise result of: ( + ) * . - /// The first tensor, represented as a span. - /// The second tensor, represented as a scalar. - /// The third tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// This method effectively does [i] = ([i] + ) * [i]. - public static void AddMultiply(ReadOnlySpan x, float y, ReadOnlySpan multiplier, Span destination) => - InvokeSpanScalarSpanIntoSpan(x, y, multiplier, destination); + return MathF.Sqrt(Aggregate(x, y)); + } - /// Computes the element-wise result of: ( * ) + . + /// Computes the element-wise division of single-precision floating-point numbers in the specified tensors. /// The first tensor, represented as a span. /// The second tensor, represented as a span. - /// The third tensor, represented as a span. /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Length of '' must be same as length of ''. + /// Length of must be same as length of . /// Destination is too short. - /// This method effectively does [i] = ([i] * [i]) + [i]. - public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan addend, Span destination) => - InvokeSpanSpanSpanIntoSpan(x, y, addend, destination); + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = [i] / [i]. + /// + /// + /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void Divide(ReadOnlySpan x, ReadOnlySpan y, Span destination) => + InvokeSpanSpanIntoSpan(x, y, destination); - /// Computes the element-wise result of: ( * ) + . + /// Computes the element-wise division of single-precision floating-point numbers in the specified tensors. /// The first tensor, represented as a span. - /// The second tensor, represented as a span. - /// The third tensor, represented as a span. + /// The second tensor, represented as a scalar. /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. /// - /// This method effectively does [i] = ([i] * [i]) + . - /// This method corresponds to the axpy method defined by BLAS1. + /// + /// This method effectively computes [i] = [i] / . + /// + /// + /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. 
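
Likewise, the Distance remarks describe a Subtract-then-SumOfSquares equivalence; a sketch with illustrative values (the temporary buffer is only needed in the manual form):

using System;
using System.Numerics.Tensors;

float[] x = { 3f, 0f };
float[] y = { 0f, 4f };

float distance = TensorPrimitives.Distance(x, y); // expected: 5

float[] difference = new float[x.Length];
TensorPrimitives.Subtract(x, y, difference);
float manual = MathF.Sqrt(TensorPrimitives.SumOfSquares(difference)); // also 5

Console.WriteLine($"{distance} {manual}");
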
+ /// /// - public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, float addend, Span destination) => - InvokeSpanSpanScalarIntoSpan(x, y, addend, destination); + public static void Divide(ReadOnlySpan x, float y, Span destination) => + InvokeSpanScalarIntoSpan(x, y, destination); - /// Computes the element-wise result of: ( * ) + . + /// Computes the dot product of two tensors containing single-precision floating-point numbers. /// The first tensor, represented as a span. /// The second tensor, represented as a span. - /// The third tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// This method effectively does [i] = ([i] * ) + [i]. - public static void MultiplyAdd(ReadOnlySpan x, float y, ReadOnlySpan addend, Span destination) => - InvokeSpanScalarSpanIntoSpan(x, y, addend, destination); + /// The dot product. + /// Length of must be same as length of . + /// + /// + /// This method effectively computes the equivalent of: + /// + /// Span<float> products = ...; + /// TensorPrimitives.Multiply(x, y, products); + /// float result = TensorPrimitives.Sum(products); + /// + /// but without requiring additional temporary storage for the intermediate products. It corresponds to the dot method defined by BLAS1. + /// + /// + /// If any of the input elements is equal to , the resulting value is also NaN. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static float Dot(ReadOnlySpan x, ReadOnlySpan y) => + Aggregate(x, y); - /// Computes the element-wise result of: pow(e, ). + /// Computes the element-wise result of raising e to the single-precision floating-point number powers in the specified tensor. /// The tensor, represented as a span. /// The destination tensor, represented as a span. /// Destination is too short. - /// This method effectively does [i] = .Exp([i]). - public static void Exp(ReadOnlySpan x, Span destination) + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = .Exp([i]). + /// + /// + /// If a value equals or , the result stored into the corresponding destination location is set to NaN. + /// If a value equals , the result stored into the corresponding destination location is set to 0. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void Exp(ReadOnlySpan x, Span destination) => + InvokeSpanIntoSpan(x, destination); + + /// Searches for the index of the largest single-precision floating-point number in the specified tensor. + /// The tensor, represented as a span. + /// The index of the maximum element in , or -1 if is empty. + /// + /// + /// The determination of the maximum element matches the IEEE 754:2019 `maximum` function. If any value equal to + /// is present, the index of the first is returned. Positive 0 is considered greater than negative 0. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. 
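
A sketch of the Dot equivalence from the remarks above (illustrative values):

using System;
using System.Numerics.Tensors;

float[] x = { 1f, 2f, 3f };
float[] y = { 4f, 5f, 6f };

float dot = TensorPrimitives.Dot(x, y); // 1*4 + 2*5 + 3*6 = 32

float[] products = new float[x.Length];
TensorPrimitives.Multiply(x, y, products);
Console.WriteLine($"{dot} {TensorPrimitives.Sum(products)}"); // expected: 32 32
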
+ /// + /// + public static int IndexOfMax(ReadOnlySpan x) { - if (x.Length > destination.Length) + if (x.IsEmpty) { - ThrowHelper.ThrowArgument_DestinationTooShort(); + return -1; } - for (int i = 0; i < x.Length; i++) - { - destination[i] = MathF.Exp(x[i]); - } + return IndexOfMinMaxCore(x); } - /// Computes the element-wise result of: ln(). + /// Searches for the index of the single-precision floating-point number with the largest magnitude in the specified tensor. /// The tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Destination is too short. - /// This method effectively does [i] = .Log([i]). - public static void Log(ReadOnlySpan x, Span destination) + /// The index of the element in with the largest magnitude (absolute value), or -1 if is empty. + /// + /// + /// The determination of the maximum magnitude matches the IEEE 754:2019 `maximumMagnitude` function. If any value equal to + /// is present, the index of the first is returned. If two values have the same magnitude and one is positive and the other is negative, + /// the positive value is considered to have the larger magnitude. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static int IndexOfMaxMagnitude(ReadOnlySpan x) { - if (x.Length > destination.Length) + if (x.IsEmpty) { - ThrowHelper.ThrowArgument_DestinationTooShort(); + return -1; } - for (int i = 0; i < x.Length; i++) - { - destination[i] = MathF.Log(x[i]); - } + return IndexOfMinMaxCore(x); } - /// Computes the element-wise result of: log2(). + /// Searches for the index of the smallest single-precision floating-point number in the specified tensor. /// The tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Destination is too short. - /// This method effectively does [i] = .Log2([i]). - public static void Log2(ReadOnlySpan x, Span destination) + /// The index of the minimum element in , or -1 if is empty. + /// + /// + /// The determination of the minimum element matches the IEEE 754:2019 `minimum` function. If any value equal to + /// is present, the index of the first is returned. Negative 0 is considered smaller than positive 0. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static int IndexOfMin(ReadOnlySpan x) { - if (x.Length > destination.Length) + if (x.IsEmpty) { - ThrowHelper.ThrowArgument_DestinationTooShort(); + return -1; } - for (int i = 0; i < x.Length; i++) - { - destination[i] = Log2(x[i]); - } + return IndexOfMinMaxCore(x); } - /// Computes the element-wise result of: cosh(). + /// Searches for the index of the single-precision floating-point number with the smallest magnitude in the specified tensor. /// The tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Destination is too short. - /// This method effectively does [i] = .Cosh([i]). - public static void Cosh(ReadOnlySpan x, Span destination) + /// The index of the element in with the smallest magnitude (absolute value), or -1 if is empty. + /// + /// + /// The determination of the minimum magnitude matches the IEEE 754:2019 `minimumMagnitude` function. 
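
The NaN-propagating and empty-input behaviors of IndexOfMax described above, in a short sketch (illustrative values):

using System;
using System.Numerics.Tensors;

Console.WriteLine(TensorPrimitives.IndexOfMax(new float[] { 1f, float.NaN, 3f })); // 1: index of the first NaN
Console.WriteLine(TensorPrimitives.IndexOfMax(new float[] { 2f, 7f, 7f }));        // 1: first occurrence of the maximum
Console.WriteLine(TensorPrimitives.IndexOfMax(ReadOnlySpan<float>.Empty));         // -1 for an empty tensor
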
If any value equal to + /// is present, the index of the first is returned. If two values have the same magnitude and one is positive and the other is negative, + /// the negative value is considered to have the smaller magnitude. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static int IndexOfMinMagnitude(ReadOnlySpan x) { - if (x.Length > destination.Length) + if (x.IsEmpty) { - ThrowHelper.ThrowArgument_DestinationTooShort(); + return -1; } - for (int i = 0; i < x.Length; i++) - { - destination[i] = MathF.Sinh(x[i]); - } + return IndexOfMinMaxCore(x); } - /// Computes the element-wise result of: tanh(). + /// Computes the element-wise natural (base e) logarithm of single-precision floating-point numbers in the specified tensor. /// The tensor, represented as a span. /// The destination tensor, represented as a span. /// Destination is too short. - /// This method effectively does [i] = .Tanh([i]). - public static void Tanh(ReadOnlySpan x, Span destination) - { - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = .Log([i]). + /// + /// + /// If a value equals 0, the result stored into the corresponding destination location is set to . + /// If a value is negative or equal to , the result stored into the corresponding destination location is set to NaN. + /// If a value is positive infinity, the result stored into the corresponding destination location is set to . + /// Otherwise, if a value is positive, its natural logarithm is stored into the corresponding destination location. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void Log(ReadOnlySpan x, Span destination) => + InvokeSpanIntoSpan(x, destination); - for (int i = 0; i < x.Length; i++) - { - destination[i] = MathF.Tanh(x[i]); - } - } + /// Computes the element-wise base 2 logarithm of single-precision floating-point numbers in the specified tensor. /// The tensor, represented as a span. /// The destination tensor, represented as a span. /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = .Log2([i]). + /// + /// + /// If a value equals 0, the result stored into the corresponding destination location is set to . + /// If a value is negative or equal to , the result stored into the corresponding destination location is set to NaN. + /// If a value is positive infinity, the result stored into the corresponding destination location is set to . + /// Otherwise, if a value is positive, its base 2 logarithm is stored into the corresponding destination location. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture.
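
The special-value rules in the Log remarks, made concrete (illustrative values):

using System;
using System.Numerics.Tensors;

float[] x = { 1f, MathF.E, 0f, -1f };
float[] destination = new float[x.Length];

TensorPrimitives.Log(x, destination);
// expected: 0, 1, -Infinity (log of zero), NaN (log of a negative value)
Console.WriteLine(string.Join(", ", destination));
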
Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void Log2(ReadOnlySpan x, Span destination) => + InvokeSpanIntoSpan(x, destination); - for (int i = 0; i < x.Length; i++) - { - destination[i] = MathF.Tanh(x[i]); - } - } + /// Searches for the largest single-precision floating-point number in the specified tensor. + /// The tensor, represented as a span. + /// The maximum element in . + /// Length of must be greater than zero. + /// + /// + /// The determination of the maximum element matches the IEEE 754:2019 `maximum` function. If any value equal to + /// is present, the first is returned. Positive 0 is considered greater than negative 0. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static float Max(ReadOnlySpan x) => + MinMaxCore(x); - /// Computes the cosine similarity between two non-zero vectors. + /// Computes the element-wise maximum of the single-precision floating-point numbers in the specified tensors. /// The first tensor, represented as a span. /// The second tensor, represented as a span. - /// The cosine similarity between the two vectors. - /// Length of '' must be same as length of ''. - /// '' and '' must not be empty. - public static float CosineSimilarity(ReadOnlySpan x, ReadOnlySpan y) - { - if (x.IsEmpty) - { - ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); - } - if (x.Length != y.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); - } + /// The destination tensor, represented as a span. + /// Length of must be same as length of . + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = MathF.Max([i], [i]). + /// + /// + /// The determination of the maximum element matches the IEEE 754:2019 `maximum` function. If either value is equal to , + /// that value is stored as the result. Positive 0 is considered greater than negative 0. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void Max(ReadOnlySpan x, ReadOnlySpan y, Span destination) => + InvokeSpanSpanIntoSpan(x, y, destination); - return CosineSimilarityCore(x, y); - } + /// Searches for the single-precision floating-point number with the largest magnitude in the specified tensor. + /// The tensor, represented as a span. + /// The element in with the largest magnitude (absolute value). + /// Length of must be greater than zero. + /// + /// + /// The determination of the maximum magnitude matches the IEEE 754:2019 `maximumMagnitude` function. If any value equal to + /// is present, the first is returned. If two values have the same magnitude and one is positive and the other is negative, + /// the positive value is considered to have the larger magnitude. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. 
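
A sketch of the IEEE 754:2019 `maximum` semantics described for Max (illustrative values):

using System;
using System.Numerics.Tensors;

Console.WriteLine(TensorPrimitives.Max(new float[] { -0f, 0f }));           // 0: positive zero is treated as greater than negative zero
Console.WriteLine(TensorPrimitives.Max(new float[] { 1f, float.NaN, 3f })); // NaN propagates

float[] destination = new float[2];
TensorPrimitives.Max(new float[] { 1f, -5f }, new float[] { 0f, 5f }, destination);
Console.WriteLine(string.Join(", ", destination)); // expected: 1, 5
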
+ /// + /// + public static float MaxMagnitude(ReadOnlySpan x) => + MinMaxCore(x); - /// - /// Compute the distance between two points in Euclidean space. - /// + /// Computes the element-wise single-precision floating-point number with the largest magnitude in the specified tensors. /// The first tensor, represented as a span. /// The second tensor, represented as a span. - /// The Euclidean distance. - /// Length of '' must be same as length of ''. - /// '' and '' must not be empty. - public static float Distance(ReadOnlySpan x, ReadOnlySpan y) - { - if (x.IsEmpty) - { - ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); - } - if (x.Length != y.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); - } + /// The destination tensor, represented as a span. + /// Length of must be same as length of . + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// This method effectively computes [i] = MathF.MaxMagnitude([i], [i]). + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void MaxMagnitude(ReadOnlySpan x, ReadOnlySpan y, Span destination) => + InvokeSpanSpanIntoSpan(x, y, destination); - return MathF.Sqrt(Aggregate(0f, x, y)); - } + /// Searches for the smallest single-precision floating-point number in the specified tensor. + /// The tensor, represented as a span. + /// The minimum element in . + /// Length of must be greater than zero. + /// + /// + /// The determination of the minimum element matches the IEEE 754:2019 `minimum` function. If any value equal to + /// is present, the first is returned. Negative 0 is considered smaller than positive 0. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static float Min(ReadOnlySpan x) => + MinMaxCore(x); - /// - /// A mathematical operation that takes two vectors and returns a scalar. - /// + /// Computes the element-wise minimum of the single-precision floating-point numbers in the specified tensors. /// The first tensor, represented as a span. /// The second tensor, represented as a span. - /// The dot product. - /// Length of '' must be same as length of ''. - public static float Dot(ReadOnlySpan x, ReadOnlySpan y) // BLAS1: dot - { - if (x.Length != y.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); - } + /// The destination tensor, represented as a span. + /// Length of must be same as length of . + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = MathF.Min([i], [i]). + /// + /// + /// The determination of the minimum element matches the IEEE 754:2019 `minimum` function. If either value is equal to , + /// that value is stored as the result. Negative 0 is considered smaller than positive 0. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture.
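
MaxMagnitude compares absolute values but returns the original signed value; a minimal sketch (illustrative values):

using System;
using System.Numerics.Tensors;

float[] x = { -8f, 3f };
float[] y = { 2f, -4f };

Console.WriteLine(TensorPrimitives.MaxMagnitude(x)); // -8: |-8| is the largest magnitude

float[] destination = new float[x.Length];
TensorPrimitives.MaxMagnitude(x, y, destination);
Console.WriteLine(string.Join(", ", destination)); // expected: -8, -4
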
Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void Min(ReadOnlySpan x, ReadOnlySpan y, Span destination) => + InvokeSpanSpanIntoSpan(x, y, destination); - return Aggregate(0f, x, y); - } + /// Searches for the single-precision floating-point number with the smallest magnitude in the specified tensor. + /// The tensor, represented as a span. + /// The element in with the smallest magnitude (absolute value). + /// Length of must be greater than zero. + /// + /// + /// The determination of the minimum magnitude matches the IEEE 754:2019 `minimumMagnitude` function. If any value equal to + /// is present, the first is returned. If two values have the same magnitude and one is positive and the other is negative, + /// the negative value is considered to have the smaller magnitude. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static float MinMagnitude(ReadOnlySpan x) => + MinMaxCore(x); - /// - /// A mathematical operation that takes a vector and returns the L2 norm. - /// + /// Computes the element-wise single-precision floating-point number with the smallest magnitude in the specified tensors. /// The first tensor, represented as a span. - /// The L2 norm. - public static float Norm(ReadOnlySpan x) // BLAS1: nrm2 - { - return MathF.Sqrt(Aggregate(0f, x)); - } + /// The second tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of must be same as length of . + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// This method effectively computes [i] = MathF.MinMagnitude([i], [i]). + /// + /// + /// The determination of the minimum magnitude matches the IEEE 754:2019 `minimumMagnitude` function. If either value is equal to , + /// that value is stored as the result. If the two values have the same magnitude and one is positive and the other is negative, + /// the negative value is considered to have the smaller magnitude. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void MinMagnitude(ReadOnlySpan x, ReadOnlySpan y, Span destination) => + InvokeSpanSpanIntoSpan(x, y, destination); - /// - /// A function that takes a collection of real numbers and returns a probability distribution. - /// + /// Computes the element-wise product of single-precision floating-point numbers in the specified tensors. /// The first tensor, represented as a span. - /// The destination tensor. + /// The second tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of must be same as length of . /// Destination is too short. - /// '' must not be empty.
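
The element-wise MinMagnitude counterpart, for symmetry (illustrative values):

using System;
using System.Numerics.Tensors;

float[] x = { -8f, 3f };
float[] y = { 2f, -4f };
float[] destination = new float[x.Length];

// Picks the value with the smaller absolute value from each pair.
TensorPrimitives.MinMagnitude(x, y, destination);
Console.WriteLine(string.Join(", ", destination)); // expected: 2, 3
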
- public static void SoftMax(ReadOnlySpan x, Span destination) - { - if (x.IsEmpty) - { - ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); - } - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } - - float expSum = 0f; - - for (int i = 0; i < x.Length; i++) - { - expSum += MathF.Exp(x[i]); - } - - for (int i = 0; i < x.Length; i++) - { - destination[i] = MathF.Exp(x[i]) / expSum; - } - } + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = [i] * [i]. + /// + /// + /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void Multiply(ReadOnlySpan x, ReadOnlySpan y, Span destination) => + InvokeSpanSpanIntoSpan(x, y, destination); - /// - /// A function that takes a real number and returns a value between 0 and 1. - /// + /// Computes the element-wise product of single-precision floating-point numbers in the specified tensors. /// The first tensor, represented as a span. - /// The destination tensor. + /// The second tensor, represented as a scalar. + /// The destination tensor, represented as a span. /// Destination is too short. - /// '' must not be empty. - public static void Sigmoid(ReadOnlySpan x, Span destination) - { - if (x.IsEmpty) - { - ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); - } - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } - - for (int i = 0; i < x.Length; i++) - { - destination[i] = 1f / (1 + MathF.Exp(-x[i])); - } - } - - /// Computes the maximum element in . - /// The tensor, represented as a span. - /// The maximum element in . - /// Length of '' must be greater than zero. - public static float Max(ReadOnlySpan x) - { - if (x.IsEmpty) - { - ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); - } - - float result = float.NegativeInfinity; - - for (int i = 0; i < x.Length; i++) - { - // This matches the IEEE 754:2019 `maximum` function. - // It propagates NaN inputs back to the caller and - // otherwise returns the greater of the inputs. - // It treats +0 as greater than -0 as per the specification. - - float current = x[i]; - - if (current != result) - { - if (float.IsNaN(current)) - { - return current; - } - - if (result < current) - { - result = current; - } - } - else if (IsNegative(result)) - { - result = current; - } - } + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = [i] * . + /// It corresponds to the scal method defined by BLAS1. + /// + /// + /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void Multiply(ReadOnlySpan x, float y, Span destination) => + InvokeSpanScalarIntoSpan(x, y, destination); - return result; - } + /// Computes the element-wise result of ( * ) * for the specified tensors of single-precision floating-point numbers. + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The third tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of must be same as length of and length of . + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. 
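
The scalar Multiply overload is the BLAS1 `scal` operation noted above; a minimal sketch (illustrative values):

using System;
using System.Numerics.Tensors;

float[] x = { 1f, 2f, 3f };
float[] destination = new float[x.Length];

TensorPrimitives.Multiply(x, 2f, destination); // destination[i] = x[i] * 2
Console.WriteLine(string.Join(", ", destination)); // expected: 2, 4, 6
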
+ /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = ([i] * [i]) + [i]. + /// + /// + /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan addend, Span destination) => + InvokeSpanSpanSpanIntoSpan(x, y, addend, destination); - /// Computes the element-wise result of: MathF.Max(, ). + /// Computes the element-wise result of ( * ) * for the specified tensors of single-precision floating-point numbers. /// The first tensor, represented as a span. /// The second tensor, represented as a span. + /// The third tensor, represented as a scalar. /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. + /// Length of must be same as length of . /// Destination is too short. - /// This method effectively does [i] = MathF.Max([i], [i]). - public static void Max(ReadOnlySpan x, ReadOnlySpan y, Span destination) - { - if (x.Length != y.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); - } + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = ([i] * [i]) + . + /// It corresponds to the axpy method defined by BLAS1. + /// + /// + /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, float addend, Span destination) => + InvokeSpanSpanScalarIntoSpan(x, y, addend, destination); - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } + /// Computes the element-wise result of ( * ) * for the specified tensors of single-precision floating-point numbers. + /// The first tensor, represented as a span. + /// The second tensor, represented as a scalar. + /// The third tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of must be same as length of . + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = ([i] * ) + [i]. + /// + /// + /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void MultiplyAdd(ReadOnlySpan x, float y, ReadOnlySpan addend, Span destination) => + InvokeSpanScalarSpanIntoSpan(x, y, addend, destination); - for (int i = 0; i < x.Length; i++) - { - destination[i] = MathF.Max(x[i], y[i]); - } - } + /// Computes the element-wise negation of each single-precision floating-point number in the specified tensor. + /// The tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = -[i]. + /// + /// + /// If any of the element-wise input values is equal to , the resulting element-wise value is also NaN. 
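
Per the remarks above, every MultiplyAdd overload computes (x * y) + addend element-wise; the scalar-addend form is the BLAS1 `axpy` shape. A minimal sketch (illustrative values):

using System;
using System.Numerics.Tensors;

float[] x = { 1f, 2f };
float[] y = { 3f, 4f };
float[] destination = new float[x.Length];

// destination[i] = (x[i] * y[i]) + 10
TensorPrimitives.MultiplyAdd(x, y, 10f, destination);
Console.WriteLine(string.Join(", ", destination)); // expected: 13, 18
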
+ /// + /// + public static void Negate(ReadOnlySpan x, Span destination) => + InvokeSpanIntoSpan(x, destination); + + /// Computes the Euclidean norm of the specified tensor of single-precision floating-point numbers. + /// The first tensor, represented as a span. + /// The norm. + /// + /// + /// This method effectively computes MathF.Sqrt(TensorPrimitives.SumOfSquares(x)). + /// This is often referred to as the Euclidean norm or L2 norm. + /// It corresponds to the nrm2 method defined by BLAS1. + /// + /// + /// If any of the input values is equal to , the result value is also NaN. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static float Norm(ReadOnlySpan x) => + MathF.Sqrt(SumOfSquares(x)); - /// Computes the minimum element in . + /// Computes the product of all elements in the specified non-empty tensor of single-precision floating-point numbers. /// The tensor, represented as a span. - /// The minimum element in . - /// Length of '' must be greater than zero. - public static float Min(ReadOnlySpan x) + /// The result of multiplying all elements in . + /// Length of must be greater than zero. + /// + /// + /// If any of the input values is equal to , the result value is also NaN. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static float Product(ReadOnlySpan x) { if (x.IsEmpty) { ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); } - float result = float.PositiveInfinity; - - for (int i = 0; i < x.Length; i++) - { - // This matches the IEEE 754:2019 `minimum` function - // It propagates NaN inputs back to the caller and - // otherwise returns the lesser of the inputs. - // It treats +0 as greater than -0 as per the specification. - - float current = x[i]; - - if (current != result) - { - if (float.IsNaN(current)) - { - return current; - } - - if (current < result) - { - result = current; - } - } - else if (IsNegative(current)) - { - result = current; - } - } - - return result; + return Aggregate(x); } - /// Computes the element-wise result of: MathF.Min(, ). + /// Computes the product of the element-wise differences of the single-precision floating-point numbers in the specified non-empty tensors. /// The first tensor, represented as a span. /// The second tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// This method effectively does [i] = MathF.Min([i], [i]). - public static void Min(ReadOnlySpan x, ReadOnlySpan y, Span destination) - { - if (x.Length != y.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); - } - - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } - - for (int i = 0; i < x.Length; i++) - { - destination[i] = MathF.Min(x[i], y[i]); - } - } - - /// Computes the maximum magnitude of any element in . - /// The tensor, represented as a span. - /// The maximum magnitude of any element in . - /// Length of '' must be greater than zero. - public static float MaxMagnitude(ReadOnlySpan x) + /// The result of multiplying the element-wise subtraction of the elements in the second tensor from the first tensor. 
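
A sketch of the Norm equivalence from the remarks above (illustrative values):

using System;
using System.Numerics.Tensors;

float[] x = { 3f, 4f };

float norm = TensorPrimitives.Norm(x);                       // expected: 5
float manual = MathF.Sqrt(TensorPrimitives.SumOfSquares(x)); // also 5
Console.WriteLine($"{norm} {manual}");
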
+ /// Length of both input spans must be greater than zero. + /// and must have the same length. + /// + /// + /// This method effectively computes: + /// + /// Span<float> differences = ...; + /// TensorPrimitives.Subtract(x, y, differences); + /// float result = TensorPrimitives.Product(differences); + /// + /// but without requiring additional temporary storage for the intermediate differences. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static float ProductOfDifferences(ReadOnlySpan x, ReadOnlySpan y) { if (x.IsEmpty) { ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); } - float result = float.NegativeInfinity; - float resultMag = float.NegativeInfinity; - - for (int i = 0; i < x.Length; i++) - { - // This matches the IEEE 754:2019 `maximumMagnitude` function. - // It propagates NaN inputs back to the caller and - // otherwise returns the input with a greater magnitude. - // It treats +0 as greater than -0 as per the specification. - - float current = x[i]; - float currentMag = Math.Abs(current); - - if (currentMag != resultMag) - { - if (float.IsNaN(currentMag)) - { - return currentMag; - } - - if (resultMag < currentMag) - { - result = current; - resultMag = currentMag; - } - } - else if (IsNegative(result)) - { - result = current; - resultMag = currentMag; - } - } - - return result; + return Aggregate(x, y); } - /// Computes the element-wise result of: MathF.MaxMagnitude(, ). + /// Computes the product of the element-wise sums of the single-precision floating-point numbers in the specified non-empty tensors. /// The first tensor, represented as a span. /// The second tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// This method effectively does [i] = MathF.MaxMagnitude([i], [i]). - public static void MaxMagnitude(ReadOnlySpan x, ReadOnlySpan y, Span destination) + /// The result of multiplying the element-wise additions of the elements in each tensor. + /// Length of both input spans must be greater than zero. + /// and must have the same length. + /// + /// + /// This method effectively computes: + /// + /// Span<float> sums = ...; + /// TensorPrimitives.Add(x, y, sums); + /// float result = TensorPrimitives.Product(sums); + /// + /// but without requiring additional temporary storage for the intermediate sums. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static float ProductOfSums(ReadOnlySpan x, ReadOnlySpan y) { - if (x.Length != y.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); - } - - if (x.Length > destination.Length) + if (x.IsEmpty) { - ThrowHelper.ThrowArgument_DestinationTooShort(); + ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); } - for (int i = 0; i < x.Length; i++) - { - destination[i] = MaxMagnitude(x[i], y[i]); - } + return Aggregate(x, y); } - /// Computes the minimum magnitude of any element in . + /// Computes the element-wise sigmoid function on the specified non-empty tensor of single-precision floating-point numbers. /// The tensor, represented as a span. - /// The minimum magnitude of any element in . 
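// Reviewer sketch (not part of the patch): scalar model of the fused aggregations above.
// ProductOfSums and ProductOfDifferences combine the element-wise operation and the running
// product in a single pass, avoiding the temporary buffer the two-call form would need.
static float ProductOfSumsReference(ReadOnlySpan<float> x, ReadOnlySpan<float> y)
{
    float result = 1f;
    for (int i = 0; i < x.Length; i++)
    {
        result *= x[i] + y[i]; // use x[i] - y[i] to model ProductOfDifferences
    }
    return result;
}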
- /// Length of '' must be greater than zero. - public static float MinMagnitude(ReadOnlySpan x) + /// The destination tensor. + /// Destination is too short. + /// must not be empty. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = 1f / (1f + .Exp(-[i])). + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void Sigmoid(ReadOnlySpan x, Span destination) { if (x.IsEmpty) { ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); } - float result = float.PositiveInfinity; - float resultMag = float.PositiveInfinity; - - for (int i = 0; i < x.Length; i++) - { - // This matches the IEEE 754:2019 `minimumMagnitude` function. - // It propagates NaN inputs back to the caller and - // otherwise returns the input with a lesser magnitude. - // It treats +0 as greater than -0 as per the specification. - - float current = x[i]; - float currentMag = Math.Abs(current); - - if (currentMag != resultMag) - { - if (float.IsNaN(currentMag)) - { - return currentMag; - } - - if (currentMag < resultMag) - { - result = current; - resultMag = currentMag; - } - } - else if (IsNegative(current)) - { - result = current; - resultMag = currentMag; - } - } - - return result; + InvokeSpanIntoSpan(x, destination); } - /// Computes the element-wise result of: MathF.MinMagnitude(, ). - /// The first tensor, represented as a span. - /// The second tensor, represented as a span. + /// Computes the element-wise hyperbolic sine of each single-precision floating-point radian angle in the specified tensor. + /// The tensor, represented as a span. /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. /// Destination is too short. - /// This method effectively does [i] = MathF.MinMagnitude([i], [i]). - public static void MinMagnitude(ReadOnlySpan x, ReadOnlySpan y, Span destination) - { - if (x.Length != y.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); - } - - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } - - for (int i = 0; i < x.Length; i++) - { - destination[i] = MinMagnitude(x[i], y[i]); - } - } + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = .Sinh([i]). + /// + /// + /// If a value is equal to , , or , + /// the corresponding destination location is set to that value. + /// + /// + /// The angles in x must be in radians. Use or multiply by /180 to convert degrees to radians. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void Sinh(ReadOnlySpan x, Span destination) => + InvokeSpanIntoSpan(x, destination); - /// Computes the index of the maximum element in . + /// Computes the softmax function over the specified non-empty tensor of single-precision floating-point numbers. /// The tensor, represented as a span. - /// The index of the maximum element in , or -1 if is empty. - public static unsafe int IndexOfMax(ReadOnlySpan x) + /// The destination tensor. + /// Destination is too short. + /// must not be empty. 
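// Reviewer sketch (not part of the patch): the scalar form of the Sigmoid mapping documented
// above, i.e. destination[i] = 1f / (1f + MathF.Exp(-x[i])).
static float SigmoidReference(float value) => 1f / (1f + MathF.Exp(-value));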
+ /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes a sum of MathF.Exp(x[i]) for all elements in . + /// It then effectively computes [i] = MathF.Exp([i]) / sum. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void SoftMax(ReadOnlySpan x, Span destination) { - int result = -1; - - if (!x.IsEmpty) + if (x.IsEmpty) { - float max = float.NegativeInfinity; - - for (int i = 0; i < x.Length; i++) - { - // This matches the IEEE 754:2019 `maximum` function. - // It propagates NaN inputs back to the caller and - // otherwise returns the greater of the inputs. - // It treats +0 as greater than -0 as per the specification. - - float current = x[i]; - - if (current != max) - { - if (float.IsNaN(current)) - { - return i; - } - - if (max < current) - { - result = i; - max = current; - } - } - else if (IsNegative(max) && !IsNegative(current)) - { - result = i; - max = current; - } - } + ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); } - return result; - } - - /// Computes the index of the minimum element in . - /// The tensor, represented as a span. - /// The index of the minimum element in , or -1 if is empty. - public static unsafe int IndexOfMin(ReadOnlySpan x) - { - int result = -1; - - if (!x.IsEmpty) + if (x.Length > destination.Length) { - float min = float.PositiveInfinity; - - for (int i = 0; i < x.Length; i++) - { - // This matches the IEEE 754:2019 `minimum` function. - // It propagates NaN inputs back to the caller and - // otherwise returns the lesser of the inputs. - // It treats +0 as greater than -0 as per the specification. - - float current = x[i]; - - if (current != min) - { - if (float.IsNaN(current)) - { - return i; - } - - if (current < min) - { - result = i; - min = current; - } - } - else if (IsNegative(current) && !IsNegative(min)) - { - result = i; - min = current; - } - } + ThrowHelper.ThrowArgument_DestinationTooShort(); } - return result; - } + ValidateInputOutputSpanNonOverlapping(x, destination); - /// Computes the index of the element in with the maximum magnitude. - /// The tensor, represented as a span. - /// The index of the element with the maximum magnitude, or -1 if is empty. - /// This method corresponds to the iamax method defined by BLAS1. - public static unsafe int IndexOfMaxMagnitude(ReadOnlySpan x) - { - int result = -1; - - if (!x.IsEmpty) - { - float max = float.NegativeInfinity; - float maxMag = float.NegativeInfinity; - - for (int i = 0; i < x.Length; i++) - { - // This matches the IEEE 754:2019 `maximumMagnitude` function. - // It propagates NaN inputs back to the caller and - // otherwise returns the input with a greater magnitude. - // It treats +0 as greater than -0 as per the specification. - - float current = x[i]; - float currentMag = Math.Abs(current); - - if (currentMag != maxMag) - { - if (float.IsNaN(currentMag)) - { - return i; - } - - if (maxMag < currentMag) - { - result = i; - max = current; - maxMag = currentMag; - } - } - else if (IsNegative(max) && !IsNegative(current)) - { - result = i; - max = current; - maxMag = currentMag; - } - } - } + float expSum = Aggregate(x); - return result; + InvokeSpanScalarIntoSpan(x, expSum, destination); } - /// Computes the index of the element in with the minimum magnitude. - /// The tensor, represented as a span. 
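// Reviewer sketch (not part of the patch): SoftMax above is the direct two-pass form; it
// sums MathF.Exp(x[i]) and then divides each MathF.Exp(x[i]) by that sum. It does not
// subtract the maximum first, so very large inputs can overflow to infinity. The classic
// max-shifted variant, shown here purely for comparison, trades an extra pass for stability:
static void StableSoftMaxReference(ReadOnlySpan<float> x, Span<float> destination)
{
    float max = float.NegativeInfinity;
    for (int i = 0; i < x.Length; i++)
    {
        max = MathF.Max(max, x[i]);
    }

    float sum = 0f;
    for (int i = 0; i < x.Length; i++)
    {
        destination[i] = MathF.Exp(x[i] - max); // shifting by max keeps exp in range
        sum += destination[i];
    }

    for (int i = 0; i < x.Length; i++)
    {
        destination[i] /= sum;
    }
}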
-        /// The index of the element with the minimum magnitude, or -1 if x is empty.
-        public static unsafe int IndexOfMinMagnitude(ReadOnlySpan<float> x)
-        {
-            int result = -1;
-
-            if (!x.IsEmpty)
-            {
-                float min = float.PositiveInfinity;
-                float minMag = float.PositiveInfinity;
-
-                for (int i = 0; i < x.Length; i++)
-                {
-                    // This matches the IEEE 754:2019 `minimumMagnitude` function
-                    // It propagates NaN inputs back to the caller and
-                    // otherwise returns the input with a lesser magnitude.
-                    // It treats +0 as greater than -0 as per the specification.
-
-                    float current = x[i];
-                    float currentMag = Math.Abs(current);
-
-                    if (currentMag != minMag)
-                    {
-                        if (float.IsNaN(currentMag))
-                        {
-                            return i;
-                        }
-
-                        if (currentMag < minMag)
-                        {
-                            result = i;
-                            min = current;
-                            minMag = currentMag;
-                        }
-                    }
-                    else if (IsNegative(current) && !IsNegative(min))
-                    {
-                        result = i;
-                        min = current;
-                        minMag = currentMag;
-                    }
-                }
-            }
+        /// Computes the element-wise difference between single-precision floating-point numbers in the specified tensors.
+        /// The first tensor, represented as a span.
+        /// The second tensor, represented as a span.
+        /// The destination tensor, represented as a span.
+        /// Length of x must be the same as length of y.
+        /// Destination is too short.
+        /// x and destination reference overlapping memory locations and do not begin at the same location.
+        /// y and destination reference overlapping memory locations and do not begin at the same location.
+        ///
+        ///
+        /// This method effectively computes destination[i] = x[i] - y[i].
+        ///
+        ///
+        /// If either of the element-wise input values is equal to NaN, the resulting element-wise value is also NaN.
+        ///
+        ///
+        public static void Subtract(ReadOnlySpan<float> x, ReadOnlySpan<float> y, Span<float> destination) =>
+            InvokeSpanSpanIntoSpan(x, y, destination);

-            return result;
-        }

+        /// Computes the element-wise difference between single-precision floating-point numbers in the specified tensors.
+        /// The first tensor, represented as a span.
+        /// The second tensor, represented as a scalar.
+        /// The destination tensor, represented as a span.
+        /// Destination is too short.
+        /// x and destination reference overlapping memory locations and do not begin at the same location.
+        ///
+        ///
+        /// This method effectively computes destination[i] = x[i] - y.
+        ///
+        ///
+        /// If either of the element-wise input values is equal to NaN, the resulting element-wise value is also NaN.
+        ///
+        ///
+        public static void Subtract(ReadOnlySpan<float> x, float y, Span<float> destination) =>
+            InvokeSpanScalarIntoSpan(x, y, destination);

-        /// Computes the sum of all elements in x.
+        /// Computes the sum of all elements in the specified tensor of single-precision floating-point numbers.
         /// The tensor, represented as a span.
         /// The result of adding all elements in x, or zero if x is empty.
+        ///
+        ///
+        /// If any of the values in the input is equal to NaN, the result is also NaN.
+        ///
+        ///
+        /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different
+        /// operating systems or architectures.
+        ///
+        ///
        public static float Sum(ReadOnlySpan<float> x) =>
-            Aggregate(0f, x);
-
-        /// Computes the sum of the squares of every element in x.
-        /// The tensor, represented as a span.
-        /// The result of adding every element in x multiplied by itself, or zero if x is empty.
-        /// This method effectively does TensorPrimitives.Sum(TensorPrimitives.Multiply(x, x)).
-        public static float SumOfSquares(ReadOnlySpan<float> x) =>
-            Aggregate(0f, x);
+            Aggregate(x);

-        /// Computes the sum of the absolute values of every element in x.
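// Reviewer sketch (illustrative values, not part of the patch): the Subtract overloads and
// Sum documented above, on a tiny input. In-place use is allowed because the overlap check
// only rejects spans that overlap without sharing a starting location.
static void SubtractAndSumUsageSketch()
{
    Span<float> diff = stackalloc float[2];
    TensorPrimitives.Subtract(stackalloc float[] { 5f, 7f }, stackalloc float[] { 1f, 2f }, diff); // { 4, 5 }
    TensorPrimitives.Subtract(diff, 1f, diff); // in place: { 3, 4 }
    float total = TensorPrimitives.Sum(diff);  // 7
}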
+ /// Computes the sum of the absolute values of every element in the specified tensor of single-precision floating-point numbers. /// The tensor, represented as a span. /// The result of adding the absolute value of every element in , or zero if is empty. /// - /// This method effectively does .Sum(.Abs()). - /// This method corresponds to the asum method defined by BLAS1. + /// + /// This method effectively computes: + /// + /// Span<float> absoluteValues = ...; + /// TensorPrimitives.Abs(x, absoluteValues); + /// float result = TensorPrimitives.Sum(absoluteValues); + /// + /// but without requiring intermediate storage for the absolute values. It corresponds to the asum method defined by BLAS1. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// /// public static float SumOfMagnitudes(ReadOnlySpan x) => - Aggregate(0f, x); + Aggregate(x); - /// Computes the product of all elements in . + /// Computes the sum of the square of every element in the specified tensor of single-precision floating-point numbers. /// The tensor, represented as a span. - /// The result of multiplying all elements in . - /// Length of '' must be greater than zero. - public static float Product(ReadOnlySpan x) - { - if (x.IsEmpty) - { - ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); - } + /// The result of adding the square of every element in , or zero if is empty. + /// + /// + /// This method effectively computes: + /// + /// Span<float> squaredValues = ...; + /// TensorPrimitives.Multiply(x, x, squaredValues); + /// float result = TensorPrimitives.Sum(squaredValues); + /// + /// but without requiring intermediate storage for the squared values. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static float SumOfSquares(ReadOnlySpan x) => + Aggregate(x); - return Aggregate(1.0f, x); - } + /// Computes the element-wise hyperbolic tangent of each single-precision floating-point radian angle in the specified tensor. + /// The tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = .Tanh([i]). + /// + /// + /// If a value is equal to , the corresponding destination location is set to -1. + /// If a value is equal to , the corresponding destination location is set to 1. + /// If a value is equal to , the corresponding destination location is set to NaN. + /// + /// + /// The angles in x must be in radians. Use or multiply by /180 to convert degrees to radians. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void Tanh(ReadOnlySpan x, Span destination) => + InvokeSpanIntoSpan(x, destination); - /// Computes the product of the element-wise result of: + . - /// The first tensor, represented as a span. - /// The second tensor, represented as a span. - /// The result of multiplying the element-wise additions of the elements in each tensor. 
-        /// Length of both input spans must be greater than zero.
-        /// x and y must have the same length.
-        /// This method effectively does TensorPrimitives.Product(TensorPrimitives.Add(x, y)).
-        public static float ProductOfSums(ReadOnlySpan<float> x, ReadOnlySpan<float> y)
+        /// Throws an exception if the input and output spans overlap and don't begin at the same memory location.
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private static void ValidateInputOutputSpanNonOverlapping(ReadOnlySpan<float> input, Span<float> output)
         {
-            if (x.IsEmpty)
-            {
-                ThrowHelper.ThrowArgument_SpansMustBeNonEmpty();
-            }
-
-            if (x.Length != y.Length)
+            if (!Unsafe.AreSame(ref MemoryMarshal.GetReference(input), ref MemoryMarshal.GetReference(output)) &&
+                input.Overlaps(output))
             {
-                ThrowHelper.ThrowArgument_SpansMustHaveSameLength();
+                ThrowHelper.ThrowArgument_InputAndDestinationSpanMustNotOverlap();
             }
-
-            return Aggregate(1.0f, x, y);
         }

-        /// Computes the product of the element-wise result of: x - y.
-        /// The first tensor, represented as a span.
-        /// The second tensor, represented as a span.
-        /// The result of multiplying the element-wise subtraction of the elements in the second tensor from the first tensor.
-        /// Length of both input spans must be greater than zero.
-        /// x and y must have the same length.
-        /// This method effectively does TensorPrimitives.Product(TensorPrimitives.Subtract(x, y)).
-        public static float ProductOfDifferences(ReadOnlySpan<float> x, ReadOnlySpan<float> y)
-        {
-            if (x.IsEmpty)
-            {
-                ThrowHelper.ThrowArgument_SpansMustBeNonEmpty();
-            }
-
-            if (x.Length != y.Length)
-            {
-                ThrowHelper.ThrowArgument_SpansMustHaveSameLength();
-            }
-
-            return Aggregate(1.0f, x, y);
-        }

+        /// Mask used to handle alignment elements before vectorized handling of the input.
+        ///
+        /// Logically 16 rows of 16 uints. The Nth row should be used to handle N alignment elements at the
+        /// beginning of the input, where elements in the vector after that will be zeroed.
+        ///
+        /// There are actually 17 rows in the table, with the last row being a repeat of the first. This is
+        /// done because it lets the main algorithms use a simplified computation of the misalignment: we
+        /// always skip the first 16 elements, even if already aligned, so that we don't double-process
+        /// them. This allows us to avoid an additional branch.
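// Reviewer sketch (not part of the patch): what ValidateInputOutputSpanNonOverlapping above
// accepts and rejects. Fully in-place operation (same starting location) is permitted;
// partially overlapping spans are not.
static void OverlapRulesSketch()
{
    float[] buffer = new float[8];
    Span<float> whole = buffer;

    ValidateInputOutputSpanNonOverlapping(whole, whole); // OK: same starting location
    // ValidateInputOutputSpanNonOverlapping(whole.Slice(0, 4), whole.Slice(1, 4)); // throws: partial overlap
}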
+ /// + private static ReadOnlySpan AlignmentUInt32Mask_16x16 => + [ + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + ]; + + /// Mask used to handle remaining elements after vectorized handling of the input. + /// + /// Logically 16 rows of 16 uints. 
The Nth row should be used to handle N remaining elements at the + /// end of the input, where elements in the vector prior to that will be zero'd. + /// + /// Much as with the AlignmentMask table, we actually have 17 rows where the last row is a repeat of + /// the first. Doing this allows us to avoid an additional branch and instead to always process the + /// last 16 elements via a conditional select instead. + /// + private static ReadOnlySpan RemainderUInt32Mask_16x16 => + [ + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 
0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + ]; } } diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs index 8eb1769d5eaee..498e4b58da77c 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs @@ -1,14 +1,33 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Diagnostics; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Runtime.Intrinsics; +using System.Runtime.Intrinsics.Arm; +using System.Runtime.Intrinsics.X86; namespace System.Numerics.Tensors { - public static partial class TensorPrimitives + public static unsafe partial class TensorPrimitives { + /// Defines the threshold, in bytes, at which non-temporal stores will be used. + /// + /// A non-temporal store is one that allows the CPU to bypass the cache when writing to memory. + /// + /// This can be beneficial when working with large amounts of memory where the writes would otherwise + /// cause large amounts of repeated updates and evictions. The hardware optimization manuals recommend + /// the threshold to be roughly half the size of the last level of on-die cache -- that is, if you have approximately + /// 4MB of L3 cache per core, you'd want this to be approx. 1-2MB, depending on if hyperthreading was enabled. + /// + /// However, actually computing the amount of L3 cache per core can be tricky or error prone. Native memcpy + /// algorithms use a constant threshold that is typically around 256KB and we match that here for simplicity. This + /// threshold accounts for most processors in the last 10-15 years that had approx. 1MB L3 per core and support + /// hyperthreading, giving a per core last level cache of approx. 512KB. + /// + private const nuint NonTemporalByteThreshold = 256 * 1024; + /// /// Copies to , converting each /// value to its nearest representable half-precision floating-point value. @@ -16,6 +35,14 @@ public static partial class TensorPrimitives /// The source span from which to copy values. /// The destination span into which the converted values should be written. /// Destination is too short. + /// + /// + /// This method effectively computes [i] = (Half)[i]. + /// + /// + /// and must not overlap. If they do, behavior is undefined. + /// + /// public static void ConvertToHalf(ReadOnlySpan source, Span destination) { if (source.Length > destination.Length) @@ -23,361 +50,354 @@ public static void ConvertToHalf(ReadOnlySpan source, Span destinat ThrowHelper.ThrowArgument_DestinationTooShort(); } - for (int i = 0; i < source.Length; i++) + ref float sourceRef = ref MemoryMarshal.GetReference(source); + ref ushort destinationRef = ref Unsafe.As(ref MemoryMarshal.GetReference(destination)); + int i = 0, twoVectorsFromEnd; + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated) { - destination[i] = (Half)source[i]; + twoVectorsFromEnd = source.Length - (Vector512.Count * 2); + if (i <= twoVectorsFromEnd) + { + // Loop handling two input vectors / one output vector at a time. 
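// Reviewer sketch (not part of the patch): how the NonTemporalByteThreshold constant above
// would typically gate a non-temporal store. Avx.StoreAlignedNonTemporal is a real x86
// intrinsic, but the policy code around it here is purely illustrative; this patch only
// defines the threshold.
//
//     if ((nuint)destination.Length * sizeof(float) >= NonTemporalByteThreshold && Avx.IsSupported)
//     {
//         // Bypass the cache for large writes that would only evict useful lines.
//         Avx.StoreAlignedNonTemporal(alignedDestinationPtr, results);
//     }
//     else
//     {
//         results.StoreUnsafe(ref destinationRef, (uint)i);
//     }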
+ do + { + Vector512 lower = SingleToHalfAsWidenedUInt32_Vector512(Vector512.LoadUnsafe(ref sourceRef, (uint)i)); + Vector512 upper = SingleToHalfAsWidenedUInt32_Vector512(Vector512.LoadUnsafe(ref sourceRef, (uint)(i + Vector512.Count))); + Vector512.Narrow(lower, upper).StoreUnsafe(ref destinationRef, (uint)i); + + i += Vector512.Count * 2; + } + while (i <= twoVectorsFromEnd); + + // Handle any remaining elements with final vectors. + if (i != source.Length) + { + i = source.Length - (Vector512.Count * 2); + + Vector512 lower = SingleToHalfAsWidenedUInt32_Vector512(Vector512.LoadUnsafe(ref sourceRef, (uint)i)); + Vector512 upper = SingleToHalfAsWidenedUInt32_Vector512(Vector512.LoadUnsafe(ref sourceRef, (uint)(i + Vector512.Count))); + Vector512.Narrow(lower, upper).StoreUnsafe(ref destinationRef, (uint)i); + } + + return; + } } - } +#endif - /// - /// Copies to , converting each half-precision - /// floating-point value to its nearest representable value. - /// - /// The source span from which to copy values. - /// The destination span into which the converted values should be written. - /// Destination is too short. - public static void ConvertToSingle(ReadOnlySpan source, Span destination) - { - if (source.Length > destination.Length) + if (Vector256.IsHardwareAccelerated) { - ThrowHelper.ThrowArgument_DestinationTooShort(); + twoVectorsFromEnd = source.Length - (Vector256.Count * 2); + if (i <= twoVectorsFromEnd) + { + // Loop handling two input vectors / one output vector at a time. + do + { + Vector256 lower = SingleToHalfAsWidenedUInt32_Vector256(Vector256.LoadUnsafe(ref sourceRef, (uint)i)); + Vector256 upper = SingleToHalfAsWidenedUInt32_Vector256(Vector256.LoadUnsafe(ref sourceRef, (uint)(i + Vector256.Count))); + Vector256 halfs = Vector256.Narrow(lower, upper); + halfs.StoreUnsafe(ref destinationRef, (uint)i); + + i += Vector256.Count * 2; + } + while (i <= twoVectorsFromEnd); + + // Handle any remaining elements with final vectors. + if (i != source.Length) + { + i = source.Length - (Vector256.Count * 2); + + Vector256 lower = SingleToHalfAsWidenedUInt32_Vector256(Vector256.LoadUnsafe(ref sourceRef, (uint)i)); + Vector256 upper = SingleToHalfAsWidenedUInt32_Vector256(Vector256.LoadUnsafe(ref sourceRef, (uint)(i + Vector256.Count))); + Vector256.Narrow(lower, upper).StoreUnsafe(ref destinationRef, (uint)i); + } + + return; + } } - for (int i = 0; i < source.Length; i++) + if (Vector128.IsHardwareAccelerated) { - destination[i] = (float)source[i]; + twoVectorsFromEnd = source.Length - (Vector128.Count * 2); + if (i <= twoVectorsFromEnd) + { + // Loop handling two input vectors / one output vector at a time. + do + { + Vector128 lower = SingleToHalfAsWidenedUInt32_Vector128(Vector128.LoadUnsafe(ref sourceRef, (uint)i)); + Vector128 upper = SingleToHalfAsWidenedUInt32_Vector128(Vector128.LoadUnsafe(ref sourceRef, (uint)(i + Vector128.Count))); + Vector128.Narrow(lower, upper).StoreUnsafe(ref destinationRef, (uint)i); + + i += Vector128.Count * 2; + } + while (i <= twoVectorsFromEnd); + + // Handle any remaining elements with final vectors. 
+ if (i != source.Length) + { + i = source.Length - (Vector128.Count * 2); + + Vector128 lower = SingleToHalfAsWidenedUInt32_Vector128(Vector128.LoadUnsafe(ref sourceRef, (uint)i)); + Vector128 upper = SingleToHalfAsWidenedUInt32_Vector128(Vector128.LoadUnsafe(ref sourceRef, (uint)(i + Vector128.Count))); + Vector128.Narrow(lower, upper).StoreUnsafe(ref destinationRef, (uint)i); + } + + return; + } } - } - private static bool IsNegative(float f) => float.IsNegative(f); + while (i < source.Length) + { + Unsafe.Add(ref destinationRef, i) = BitConverter.HalfToUInt16Bits((Half)Unsafe.Add(ref sourceRef, i)); + i++; + } - private static float MaxMagnitude(float x, float y) => MathF.MaxMagnitude(x, y); + // This implements a vectorized version of the `explicit operator Half(float value) operator`. + // See detailed description of the algorithm used here: + // https://github.com/dotnet/runtime/blob/ca8d6f0420096831766ec11c7d400e4f7ccc7a34/src/libraries/System.Private.CoreLib/src/System/Half.cs#L606-L714 + // The cast operator converts a float to a Half represented as a UInt32, then narrows to a UInt16, and reinterpret casts to Half. + // This does the same, with an input VectorXx and an output VectorXx. + // Loop handling two input vectors at a time; each input float is double the size of each output Half, + // so we need two vectors of floats to produce one vector of Halfs. Half isn't supported in VectorXx, + // so we convert the VectorXx to a VectorXx, and the caller then uses this twice, narrows the combination + // into a VectorXx, and then saves that out to the destination `ref Half` reinterpreted as `ref ushort`. + +#pragma warning disable IDE0059 // https://github.com/dotnet/roslyn/issues/44948 + const uint MinExp = 0x3880_0000u; // Minimum exponent for rounding + const uint Exponent126 = 0x3f00_0000u; // Exponent displacement #1 + const uint SingleBiasedExponentMask = 0x7F80_0000; // float.BiasedExponentMask; // Exponent mask + const uint Exponent13 = 0x0680_0000u; // Exponent displacement #2 + const float MaxHalfValueBelowInfinity = 65520.0f; // Maximum value that is not Infinity in Half + const uint ExponentMask = 0x7C00; // Mask for exponent bits in Half + const uint SingleSignMask = 0x8000_0000u; // float.SignMask; // Mask for sign bit in float +#pragma warning restore IDE0059 + + static Vector128 SingleToHalfAsWidenedUInt32_Vector128(Vector128 value) + { + Vector128 bitValue = value.AsUInt32(); - private static float MinMagnitude(float x, float y) => MathF.MinMagnitude(x, y); + // Extract sign bit + Vector128 sign = Vector128.ShiftRightLogical(bitValue & Vector128.Create(SingleSignMask), 16); - private static float Log2(float x) => MathF.Log2(x); + // Detecting NaN (0u if value is NaN; otherwise, ~0u) + Vector128 realMask = Vector128.Equals(value, value).AsUInt32(); - private static float CosineSimilarityCore(ReadOnlySpan x, ReadOnlySpan y) - { - // Compute the same as: - // TensorPrimitives.Dot(x, y) / (Math.Sqrt(TensorPrimitives.SumOfSquares(x)) * Math.Sqrt(TensorPrimitives.SumOfSquares(y))) - // but only looping over each span once. + // Clear sign bit + value = Vector128.Abs(value); - float dotProduct = 0f; - float xSumOfSquares = 0f; - float ySumOfSquares = 0f; + // Rectify values that are Infinity in Half. 
+ value = Vector128.Min(Vector128.Create(MaxHalfValueBelowInfinity), value); - int i = 0; + // Rectify lower exponent + Vector128 exponentOffset0 = Vector128.Max(value, Vector128.Create(MinExp).AsSingle()).AsUInt32(); -#if NET8_0_OR_GREATER - if (Vector512.IsHardwareAccelerated && x.Length >= Vector512.Count) - { - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float yRef = ref MemoryMarshal.GetReference(y); + // Extract exponent + exponentOffset0 &= Vector128.Create(SingleBiasedExponentMask); - Vector512 dotProductVector = Vector512.Zero; - Vector512 xSumOfSquaresVector = Vector512.Zero; - Vector512 ySumOfSquaresVector = Vector512.Zero; + // Add exponent by 13 + exponentOffset0 += Vector128.Create(Exponent13); - // Process vectors, summing their dot products and squares, as long as there's a vector's worth remaining. - int oneVectorFromEnd = x.Length - Vector512.Count; - do - { - Vector512 xVec = Vector512.LoadUnsafe(ref xRef, (uint)i); - Vector512 yVec = Vector512.LoadUnsafe(ref yRef, (uint)i); + // Round Single into Half's precision (NaN also gets modified here, just setting the MSB of fraction) + value += exponentOffset0.AsSingle(); + bitValue = value.AsUInt32(); - dotProductVector += xVec * yVec; - xSumOfSquaresVector += xVec * xVec; - ySumOfSquaresVector += yVec * yVec; + // Only exponent bits will be modified if NaN + Vector128 maskedHalfExponentForNaN = ~realMask & Vector128.Create(ExponentMask); - i += Vector512.Count; - } - while (i <= oneVectorFromEnd); + // Subtract exponent by 126 + bitValue -= Vector128.Create(Exponent126); - // Sum the vector lanes into the scalar result. - dotProduct += Vector512.Sum(dotProductVector); - xSumOfSquares += Vector512.Sum(xSumOfSquaresVector); - ySumOfSquares += Vector512.Sum(ySumOfSquaresVector); - } - else -#endif - if (Vector256.IsHardwareAccelerated && x.Length >= Vector256.Count) - { - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float yRef = ref MemoryMarshal.GetReference(y); + // Shift bitValue right by 13 bits to match the boundary of exponent part and fraction part. + Vector128 newExponent = Vector128.ShiftRightLogical(bitValue, 13); - Vector256 dotProductVector = Vector256.Zero; - Vector256 xSumOfSquaresVector = Vector256.Zero; - Vector256 ySumOfSquaresVector = Vector256.Zero; + // Clear the fraction parts if the value was NaN. + bitValue &= realMask; - // Process vectors, summing their dot products and squares, as long as there's a vector's worth remaining. - int oneVectorFromEnd = x.Length - Vector256.Count; - do - { - Vector256 xVec = Vector256.LoadUnsafe(ref xRef, (uint)i); - Vector256 yVec = Vector256.LoadUnsafe(ref yRef, (uint)i); + // Merge the exponent part with fraction part, and add the exponent part and fraction part's overflow. + bitValue += newExponent; - dotProductVector += xVec * yVec; - xSumOfSquaresVector += xVec * xVec; - ySumOfSquaresVector += yVec * yVec; + // Clear exponents if value is NaN + bitValue &= ~maskedHalfExponentForNaN; - i += Vector256.Count; - } - while (i <= oneVectorFromEnd); + // Merge sign bit with possible NaN exponent + Vector128 signAndMaskedExponent = maskedHalfExponentForNaN | sign; - // Sum the vector lanes into the scalar result. 
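// Reviewer sketch (not part of the patch): the "Detecting NaN" trick used above relies on
// NaN being the only value for which x == x is false, so comparing a vector with itself
// yields an all-ones lane mask exactly in the non-NaN lanes.
//
//     Vector128<float> v = Vector128.Create(1f, float.NaN, -0f, float.PositiveInfinity);
//     Vector128<uint> realMask = Vector128.Equals(v, v).AsUInt32();
//     // realMask == <0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF>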
- dotProduct += Vector256.Sum(dotProductVector); - xSumOfSquares += Vector256.Sum(xSumOfSquaresVector); - ySumOfSquares += Vector256.Sum(ySumOfSquaresVector); + // Merge sign bit and possible NaN exponent + bitValue |= signAndMaskedExponent; + + // The final result + return bitValue; } - else if (Vector128.IsHardwareAccelerated && x.Length >= Vector128.Count) + + static Vector256 SingleToHalfAsWidenedUInt32_Vector256(Vector256 value) { - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float yRef = ref MemoryMarshal.GetReference(y); + Vector256 bitValue = value.AsUInt32(); - Vector128 dotProductVector = Vector128.Zero; - Vector128 xSumOfSquaresVector = Vector128.Zero; - Vector128 ySumOfSquaresVector = Vector128.Zero; + // Extract sign bit + Vector256 sign = Vector256.ShiftRightLogical(bitValue & Vector256.Create(SingleSignMask), 16); - // Process vectors, summing their dot products and squares, as long as there's a vector's worth remaining. - int oneVectorFromEnd = x.Length - Vector128.Count; - do - { - Vector128 xVec = Vector128.LoadUnsafe(ref xRef, (uint)i); - Vector128 yVec = Vector128.LoadUnsafe(ref yRef, (uint)i); + // Detecting NaN (0u if value is NaN; otherwise, ~0u) + Vector256 realMask = Vector256.Equals(value, value).AsUInt32(); - dotProductVector += xVec * yVec; - xSumOfSquaresVector += xVec * xVec; - ySumOfSquaresVector += yVec * yVec; + // Clear sign bit + value = Vector256.Abs(value); - i += Vector128.Count; - } - while (i <= oneVectorFromEnd); + // Rectify values that are Infinity in Half. + value = Vector256.Min(Vector256.Create(MaxHalfValueBelowInfinity), value); - // Sum the vector lanes into the scalar result. - dotProduct += Vector128.Sum(dotProductVector); - xSumOfSquares += Vector128.Sum(xSumOfSquaresVector); - ySumOfSquares += Vector128.Sum(ySumOfSquaresVector); - } + // Rectify lower exponent + Vector256 exponentOffset0 = Vector256.Max(value, Vector256.Create(MinExp).AsSingle()).AsUInt32(); - // Process any remaining elements past the last vector. - for (; (uint)i < (uint)x.Length; i++) - { - dotProduct += x[i] * y[i]; - xSumOfSquares += x[i] * x[i]; - ySumOfSquares += y[i] * y[i]; - } + // Extract exponent + exponentOffset0 &= Vector256.Create(SingleBiasedExponentMask); - // Sum(X * Y) / (|X| * |Y|) - return dotProduct / (MathF.Sqrt(xSumOfSquares) * MathF.Sqrt(ySumOfSquares)); - } + // Add exponent by 13 + exponentOffset0 += Vector256.Create(Exponent13); - private static float Aggregate( - float identityValue, ReadOnlySpan x) - where TLoad : IUnaryOperator - where TAggregate : IBinaryOperator - { - // Initialize the result to the identity value - float result = identityValue; - int i = 0; + // Round Single into Half's precision (NaN also gets modified here, just setting the MSB of fraction) + value += exponentOffset0.AsSingle(); + bitValue = value.AsUInt32(); -#if NET8_0_OR_GREATER - if (Vector512.IsHardwareAccelerated && x.Length >= Vector512.Count * 2) - { - ref float xRef = ref MemoryMarshal.GetReference(x); + // Only exponent bits will be modified if NaN + Vector256 maskedHalfExponentForNaN = ~realMask & Vector256.Create(ExponentMask); - // Load the first vector as the initial set of results - Vector512 resultVector = TLoad.Invoke(Vector512.LoadUnsafe(ref xRef, 0)); - int oneVectorFromEnd = x.Length - Vector512.Count; + // Subtract exponent by 126 + bitValue -= Vector256.Create(Exponent126); - // Aggregate additional vectors into the result as long as there's at - // least one full vector left to process. 
- i = Vector512.Count; - do - { - resultVector = TAggregate.Invoke(resultVector, TLoad.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i))); - i += Vector512.Count; - } - while (i <= oneVectorFromEnd); + // Shift bitValue right by 13 bits to match the boundary of exponent part and fraction part. + Vector256 newExponent = Vector256.ShiftRightLogical(bitValue, 13); - // Aggregate the lanes in the vector back into the scalar result - result = TAggregate.Invoke(result, TAggregate.Invoke(resultVector)); - } - else -#endif - if (Vector256.IsHardwareAccelerated && x.Length >= Vector256.Count * 2) - { - ref float xRef = ref MemoryMarshal.GetReference(x); + // Clear the fraction parts if the value was NaN. + bitValue &= realMask; - // Load the first vector as the initial set of results - Vector256 resultVector = TLoad.Invoke(Vector256.LoadUnsafe(ref xRef, 0)); - int oneVectorFromEnd = x.Length - Vector256.Count; + // Merge the exponent part with fraction part, and add the exponent part and fraction part's overflow. + bitValue += newExponent; - // Aggregate additional vectors into the result as long as there's at - // least one full vector left to process. - i = Vector256.Count; - do - { - resultVector = TAggregate.Invoke(resultVector, TLoad.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i))); - i += Vector256.Count; - } - while (i <= oneVectorFromEnd); + // Clear exponents if value is NaN + bitValue &= ~maskedHalfExponentForNaN; + + // Merge sign bit with possible NaN exponent + Vector256 signAndMaskedExponent = maskedHalfExponentForNaN | sign; - // Aggregate the lanes in the vector back into the scalar result - result = TAggregate.Invoke(result, TAggregate.Invoke(resultVector)); + // Merge sign bit and possible NaN exponent + bitValue |= signAndMaskedExponent; + + // The final result + return bitValue; } - else if (Vector128.IsHardwareAccelerated && x.Length >= Vector128.Count * 2) + +#if NET8_0_OR_GREATER + static Vector512 SingleToHalfAsWidenedUInt32_Vector512(Vector512 value) { - ref float xRef = ref MemoryMarshal.GetReference(x); + Vector512 bitValue = value.AsUInt32(); - // Load the first vector as the initial set of results - Vector128 resultVector = TLoad.Invoke(Vector128.LoadUnsafe(ref xRef, 0)); - int oneVectorFromEnd = x.Length - Vector128.Count; + // Extract sign bit + Vector512 sign = Vector512.ShiftRightLogical(bitValue & Vector512.Create(SingleSignMask), 16); - // Aggregate additional vectors into the result as long as there's at - // least one full vector left to process. - i = Vector128.Count; - do - { - resultVector = TAggregate.Invoke(resultVector, TLoad.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i))); - i += Vector128.Count; - } - while (i <= oneVectorFromEnd); + // Detecting NaN (0u if value is NaN; otherwise, ~0u) + Vector512 realMask = Vector512.Equals(value, value).AsUInt32(); - // Aggregate the lanes in the vector back into the scalar result - result = TAggregate.Invoke(result, TAggregate.Invoke(resultVector)); - } + // Clear sign bit + value = Vector512.Abs(value); - // Aggregate the remaining items in the input span. - for (; (uint)i < (uint)x.Length; i++) - { - result = TAggregate.Invoke(result, TLoad.Invoke(x[i])); - } + // Rectify values that are Infinity in Half. 
+ value = Vector512.Min(Vector512.Create(MaxHalfValueBelowInfinity), value); - return result; - } + // Rectify lower exponent + Vector512 exponentOffset0 = Vector512.Max(value, Vector512.Create(MinExp).AsSingle()).AsUInt32(); - private static float Aggregate( - float identityValue, ReadOnlySpan x, ReadOnlySpan y) - where TBinary : IBinaryOperator - where TAggregate : IBinaryOperator - { - // Initialize the result to the identity value - float result = identityValue; - int i = 0; + // Extract exponent + exponentOffset0 &= Vector512.Create(SingleBiasedExponentMask); -#if NET8_0_OR_GREATER - if (Vector512.IsHardwareAccelerated && x.Length >= Vector512.Count * 2) - { - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float yRef = ref MemoryMarshal.GetReference(y); + // Add exponent by 13 + exponentOffset0 += Vector512.Create(Exponent13); - // Load the first vector as the initial set of results - Vector512 resultVector = TBinary.Invoke(Vector512.LoadUnsafe(ref xRef, 0), Vector512.LoadUnsafe(ref yRef, 0)); - int oneVectorFromEnd = x.Length - Vector512.Count; + // Round Single into Half's precision (NaN also gets modified here, just setting the MSB of fraction) + value += exponentOffset0.AsSingle(); + bitValue = value.AsUInt32(); - // Aggregate additional vectors into the result as long as there's at - // least one full vector left to process. - i = Vector512.Count; - do - { - resultVector = TAggregate.Invoke(resultVector, TBinary.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i), Vector512.LoadUnsafe(ref yRef, (uint)i))); - i += Vector512.Count; - } - while (i <= oneVectorFromEnd); + // Only exponent bits will be modified if NaN + Vector512 maskedHalfExponentForNaN = ~realMask & Vector512.Create(ExponentMask); - // Aggregate the lanes in the vector back into the scalar result - result = TAggregate.Invoke(result, TAggregate.Invoke(resultVector)); - } - else -#endif - if (Vector256.IsHardwareAccelerated && x.Length >= Vector256.Count * 2) - { - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float yRef = ref MemoryMarshal.GetReference(y); + // Subtract exponent by 126 + bitValue -= Vector512.Create(Exponent126); - // Load the first vector as the initial set of results - Vector256 resultVector = TBinary.Invoke(Vector256.LoadUnsafe(ref xRef, 0), Vector256.LoadUnsafe(ref yRef, 0)); - int oneVectorFromEnd = x.Length - Vector256.Count; + // Shift bitValue right by 13 bits to match the boundary of exponent part and fraction part. + Vector512 newExponent = Vector512.ShiftRightLogical(bitValue, 13); - // Aggregate additional vectors into the result as long as there's at - // least one full vector left to process. - i = Vector256.Count; - do - { - resultVector = TAggregate.Invoke(resultVector, TBinary.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i), Vector256.LoadUnsafe(ref yRef, (uint)i))); - i += Vector256.Count; - } - while (i <= oneVectorFromEnd); + // Clear the fraction parts if the value was NaN. + bitValue &= realMask; - // Aggregate the lanes in the vector back into the scalar result - result = TAggregate.Invoke(result, TAggregate.Invoke(resultVector)); - } - else if (Vector128.IsHardwareAccelerated && x.Length >= Vector128.Count * 2) - { - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float yRef = ref MemoryMarshal.GetReference(y); + // Merge the exponent part with fraction part, and add the exponent part and fraction part's overflow. 
+ bitValue += newExponent; - // Load the first vector as the initial set of results - Vector128 resultVector = TBinary.Invoke(Vector128.LoadUnsafe(ref xRef, 0), Vector128.LoadUnsafe(ref yRef, 0)); - int oneVectorFromEnd = x.Length - Vector128.Count; + // Clear exponents if value is NaN + bitValue &= ~maskedHalfExponentForNaN; - // Aggregate additional vectors into the result as long as there's at - // least one full vector left to process. - i = Vector128.Count; - do - { - resultVector = TAggregate.Invoke(resultVector, TBinary.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i), Vector128.LoadUnsafe(ref yRef, (uint)i))); - i += Vector128.Count; - } - while (i <= oneVectorFromEnd); + // Merge sign bit with possible NaN exponent + Vector512 signAndMaskedExponent = maskedHalfExponentForNaN | sign; - // Aggregate the lanes in the vector back into the scalar result - result = TAggregate.Invoke(result, TAggregate.Invoke(resultVector)); - } + // Merge sign bit and possible NaN exponent + bitValue |= signAndMaskedExponent; - // Aggregate the remaining items in the input span. - for (; (uint)i < (uint)x.Length; i++) - { - result = TAggregate.Invoke(result, TBinary.Invoke(x[i], y[i])); + // The final result + return bitValue; } - - return result; +#endif } - private static unsafe void InvokeSpanIntoSpan( - ReadOnlySpan x, Span destination) - where TUnaryOperator : IUnaryOperator + /// + /// Copies to , converting each half-precision + /// floating-point value to its nearest representable value. + /// + /// The source span from which to copy values. + /// The destination span into which the converted values should be written. + /// Destination is too short. + /// + /// + /// This method effectively computes [i] = (float)[i]. + /// + /// + /// and must not overlap. If they do, behavior is undefined. + /// + /// + public static void ConvertToSingle(ReadOnlySpan source, Span destination) { - if (x.Length > destination.Length) + if (source.Length > destination.Length) { ThrowHelper.ThrowArgument_DestinationTooShort(); } - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float dRef = ref MemoryMarshal.GetReference(destination); + ref short sourceRef = ref Unsafe.As(ref MemoryMarshal.GetReference(source)); + ref float destinationRef = ref MemoryMarshal.GetReference(destination); int i = 0, oneVectorFromEnd; #if NET8_0_OR_GREATER if (Vector512.IsHardwareAccelerated) { - oneVectorFromEnd = x.Length - Vector512.Count; + oneVectorFromEnd = source.Length - Vector512.Count; if (i <= oneVectorFromEnd) { - // Loop handling one vector at a time. + // Loop handling one input vector / two output vectors at a time. do { - TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + (Vector512 lower, Vector512 upper) = Vector512.Widen(Vector512.LoadUnsafe(ref sourceRef, (uint)i)); + HalfAsWidenedUInt32ToSingle_Vector512(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i); + HalfAsWidenedUInt32ToSingle_Vector512(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector512.Count)); - i += Vector512.Count; + i += Vector512.Count; } while (i <= oneVectorFromEnd); - // Handle any remaining elements with a final vector. - if (i != x.Length) + // Handle any remaining elements with a final input vector. 
+ if (i != source.Length) { - uint lastVectorIndex = (uint)(x.Length - Vector512.Count); - TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); + i = source.Length - Vector512.Count; + + (Vector512 lower, Vector512 upper) = Vector512.Widen(Vector512.LoadUnsafe(ref sourceRef, (uint)i)); + HalfAsWidenedUInt32ToSingle_Vector512(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i); + HalfAsWidenedUInt32ToSingle_Vector512(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector512.Count)); } return; @@ -387,23 +407,28 @@ private static unsafe void InvokeSpanIntoSpan( if (Vector256.IsHardwareAccelerated) { - oneVectorFromEnd = x.Length - Vector256.Count; + oneVectorFromEnd = source.Length - Vector256.Count; if (i <= oneVectorFromEnd) { - // Loop handling one vector at a time. + // Loop handling one input vector / two output vectors at a time. do { - TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + (Vector256 lower, Vector256 upper) = Vector256.Widen(Vector256.LoadUnsafe(ref sourceRef, (uint)i)); + HalfAsWidenedUInt32ToSingle_Vector256(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i); + HalfAsWidenedUInt32ToSingle_Vector256(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector256.Count)); - i += Vector256.Count; + i += Vector256.Count; } while (i <= oneVectorFromEnd); - // Handle any remaining elements with a final vector. - if (i != x.Length) + // Handle any remaining elements with a final input vector. + if (i != source.Length) { - uint lastVectorIndex = (uint)(x.Length - Vector256.Count); - TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); + i = source.Length - Vector256.Count; + + (Vector256 lower, Vector256 upper) = Vector256.Widen(Vector256.LoadUnsafe(ref sourceRef, (uint)i)); + HalfAsWidenedUInt32ToSingle_Vector256(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i); + HalfAsWidenedUInt32ToSingle_Vector256(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector256.Count)); } return; @@ -412,832 +437,11129 @@ private static unsafe void InvokeSpanIntoSpan( if (Vector128.IsHardwareAccelerated) { - oneVectorFromEnd = x.Length - Vector128.Count; + oneVectorFromEnd = source.Length - Vector128.Count; if (i <= oneVectorFromEnd) { - // Loop handling one vector at a time. + // Loop handling one input vector / two output vectors at a time. do { - TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + (Vector128 lower, Vector128 upper) = Vector128.Widen(Vector128.LoadUnsafe(ref sourceRef, (uint)i)); + HalfAsWidenedUInt32ToSingle_Vector128(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i); + HalfAsWidenedUInt32ToSingle_Vector128(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector128.Count)); - i += Vector128.Count; + i += Vector128.Count; } while (i <= oneVectorFromEnd); - // Handle any remaining elements with a final vector. - if (i != x.Length) + // Handle any remaining elements with a final input vector. 
+ if (i != source.Length) { - uint lastVectorIndex = (uint)(x.Length - Vector128.Count); - TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); + i = source.Length - Vector128.Count; + + (Vector128 lower, Vector128 upper) = Vector128.Widen(Vector128.LoadUnsafe(ref sourceRef, (uint)i)); + HalfAsWidenedUInt32ToSingle_Vector128(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i); + HalfAsWidenedUInt32ToSingle_Vector128(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector128.Count)); } return; } } - while (i < x.Length) + while (i < source.Length) { - Unsafe.Add(ref dRef, i) = TUnaryOperator.Invoke(Unsafe.Add(ref xRef, i)); - + Unsafe.Add(ref destinationRef, i) = (float)Unsafe.As(ref Unsafe.Add(ref sourceRef, i)); i++; } - } - private static unsafe void InvokeSpanSpanIntoSpan( - ReadOnlySpan x, ReadOnlySpan y, Span destination) - where TBinaryOperator : IBinaryOperator - { - if (x.Length != y.Length) + // This implements a vectorized version of the `explicit operator float(Half value) operator`. + // See detailed description of the algorithm used here: + // https://github.com/dotnet/runtime/blob/3bf40a378f00cb5bf18ff62796bc7097719b974c/src/libraries/System.Private.CoreLib/src/System/Half.cs#L1010-L1040 + // The cast operator converts a Half represented as uint to a float. This does the same, with an input VectorXx and an output VectorXx. + // The VectorXx is created by reading a vector of Halfs as a VectorXx then widened to two VectorXxs and cast to VectorXxs. + // We loop handling one input vector at a time, producing two output float vectors. + +#pragma warning disable IDE0059 // https://github.com/dotnet/roslyn/issues/44948 + const uint ExponentLowerBound = 0x3880_0000u; // The smallest positive normal number in Half, converted to Single + const uint ExponentOffset = 0x3800_0000u; // BitConverter.SingleToUInt32Bits(1.0f) - ((uint)BitConverter.HalfToUInt16Bits((Half)1.0f) << 13) + const uint SingleSignMask = 0x8000_0000; // float.SignMask; // Mask for sign bit in Single + const uint HalfExponentMask = 0x7C00; // Mask for exponent bits in Half + const uint HalfToSingleBitsMask = 0x0FFF_E000; // Mask for bits in Single converted from Half +#pragma warning restore IDE0059 + + static Vector128 HalfAsWidenedUInt32ToSingle_Vector128(Vector128 value) { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); - } + // Extract sign bit of value + Vector128 sign = value & Vector128.Create(SingleSignMask); - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } + // Copy sign bit to upper bits + Vector128 bitValueInProcess = value; - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float yRef = ref MemoryMarshal.GetReference(y); - ref float dRef = ref MemoryMarshal.GetReference(destination); - int i = 0, oneVectorFromEnd; + // Extract exponent bits of value (BiasedExponent is not for here as it performs unnecessary shift) + Vector128 offsetExponent = bitValueInProcess & Vector128.Create(HalfExponentMask); -#if NET8_0_OR_GREATER - if (Vector512.IsHardwareAccelerated) - { - oneVectorFromEnd = x.Length - Vector512.Count; - if (i <= oneVectorFromEnd) - { - // Loop handling one vector at a time. 
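// Reviewer sketch (illustrative, not part of the patch): round-tripping through the two
// conversion routines in this file. float -> Half rounds to nearest, while Half -> float is
// exact, so the round trip lands on the Half-rounded value.
//
//     float[] values = { 0.1f, -2f, float.NaN, 65504f };
//     Half[] halves = new Half[values.Length];
//     float[] roundTripped = new float[values.Length];
//     TensorPrimitives.ConvertToHalf(values, halves);
//     TensorPrimitives.ConvertToSingle(halves, roundTripped);
//     // roundTripped[i] == (float)(Half)values[i]; NaN stays NaN (compare with bits, not ==).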
- do - { - TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i), - Vector512.LoadUnsafe(ref yRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + // ~0u when value is subnormal, 0 otherwise + Vector128 subnormalMask = Vector128.Equals(offsetExponent, Vector128.Zero); - i += Vector512.Count; - } - while (i <= oneVectorFromEnd); + // ~0u when value is either Infinity or NaN, 0 otherwise + Vector128 infinityOrNaNMask = Vector128.Equals(offsetExponent, Vector128.Create(HalfExponentMask)); - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector512.Count); - TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex), - Vector512.LoadUnsafe(ref yRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); - } + // 0x3880_0000u if value is subnormal, 0 otherwise + Vector128 maskedExponentLowerBound = subnormalMask & Vector128.Create(ExponentLowerBound); - return; - } - } -#endif + // 0x3880_0000u if value is subnormal, 0x3800_0000u otherwise + Vector128 offsetMaskedExponentLowerBound = Vector128.Create(ExponentOffset) | maskedExponentLowerBound; - if (Vector256.IsHardwareAccelerated) - { - oneVectorFromEnd = x.Length - Vector256.Count; - if (i <= oneVectorFromEnd) - { - // Loop handling one vector at a time. - do - { - TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i), - Vector256.LoadUnsafe(ref yRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + // Match the position of the boundary of exponent bits and fraction bits with IEEE 754 Binary32(Single) + bitValueInProcess = Vector128.ShiftLeft(bitValueInProcess, 13); - i += Vector256.Count; - } - while (i <= oneVectorFromEnd); + // Double the offsetMaskedExponentLowerBound if value is either Infinity or NaN + offsetMaskedExponentLowerBound = Vector128.ConditionalSelect(Vector128.Equals(infinityOrNaNMask, Vector128.Zero), + offsetMaskedExponentLowerBound, + Vector128.ShiftLeft(offsetMaskedExponentLowerBound, 1)); - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector256.Count); - TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex), - Vector256.LoadUnsafe(ref yRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); - } + // Extract exponent bits and fraction bits of value + bitValueInProcess &= Vector128.Create(HalfToSingleBitsMask); - return; - } + // Adjust exponent to match the range of exponent + bitValueInProcess += offsetMaskedExponentLowerBound; + + // If value is subnormal, remove unnecessary 1 on top of fraction bits. + Vector128 absoluteValue = (bitValueInProcess.AsSingle() - maskedExponentLowerBound.AsSingle()).AsUInt32(); + + // Merge sign bit with rest + return (absoluteValue | sign).AsSingle(); } - if (Vector128.IsHardwareAccelerated) + static Vector256 HalfAsWidenedUInt32ToSingle_Vector256(Vector256 value) { - oneVectorFromEnd = x.Length - Vector128.Count; - if (i <= oneVectorFromEnd) - { - // Loop handling one vector at a time. - do - { - TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i), - Vector128.LoadUnsafe(ref yRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + // Extract sign bit of value + Vector256 sign = value & Vector256.Create(SingleSignMask); - i += Vector128.Count; - } - while (i <= oneVectorFromEnd); + // Copy sign bit to upper bits + Vector256 bitValueInProcess = value; - // Handle any remaining elements with a final vector. 
- if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector128.Count); - TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex), - Vector128.LoadUnsafe(ref yRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); - } + // Extract exponent bits of value (BiasedExponent is not for here as it performs unnecessary shift) + Vector256 offsetExponent = bitValueInProcess & Vector256.Create(HalfExponentMask); - return; - } + // ~0u when value is subnormal, 0 otherwise + Vector256 subnormalMask = Vector256.Equals(offsetExponent, Vector256.Zero); + + // ~0u when value is either Infinity or NaN, 0 otherwise + Vector256 infinityOrNaNMask = Vector256.Equals(offsetExponent, Vector256.Create(HalfExponentMask)); + + // 0x3880_0000u if value is subnormal, 0 otherwise + Vector256 maskedExponentLowerBound = subnormalMask & Vector256.Create(ExponentLowerBound); + + // 0x3880_0000u if value is subnormal, 0x3800_0000u otherwise + Vector256 offsetMaskedExponentLowerBound = Vector256.Create(ExponentOffset) | maskedExponentLowerBound; + + // Match the position of the boundary of exponent bits and fraction bits with IEEE 754 Binary32(Single) + bitValueInProcess = Vector256.ShiftLeft(bitValueInProcess, 13); + + // Double the offsetMaskedExponentLowerBound if value is either Infinity or NaN + offsetMaskedExponentLowerBound = Vector256.ConditionalSelect(Vector256.Equals(infinityOrNaNMask, Vector256.Zero), + offsetMaskedExponentLowerBound, + Vector256.ShiftLeft(offsetMaskedExponentLowerBound, 1)); + + // Extract exponent bits and fraction bits of value + bitValueInProcess &= Vector256.Create(HalfToSingleBitsMask); + + // Adjust exponent to match the range of exponent + bitValueInProcess += offsetMaskedExponentLowerBound; + + // If value is subnormal, remove unnecessary 1 on top of fraction bits. 
+ Vector256 absoluteValue = (bitValueInProcess.AsSingle() - maskedExponentLowerBound.AsSingle()).AsUInt32(); + + // Merge sign bit with rest + return (absoluteValue | sign).AsSingle(); } - while (i < x.Length) +#if NET8_0_OR_GREATER + static Vector512 HalfAsWidenedUInt32ToSingle_Vector512(Vector512 value) { - Unsafe.Add(ref dRef, i) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, i), - Unsafe.Add(ref yRef, i)); + // Extract sign bit of value + Vector512 sign = value & Vector512.Create(SingleSignMask); - i++; + // Copy sign bit to upper bits + Vector512 bitValueInProcess = value; + + // Extract exponent bits of value (BiasedExponent is not for here as it performs unnecessary shift) + Vector512 offsetExponent = bitValueInProcess & Vector512.Create(HalfExponentMask); + + // ~0u when value is subnormal, 0 otherwise + Vector512 subnormalMask = Vector512.Equals(offsetExponent, Vector512.Zero); + + // ~0u when value is either Infinity or NaN, 0 otherwise + Vector512 infinityOrNaNMask = Vector512.Equals(offsetExponent, Vector512.Create(HalfExponentMask)); + + // 0x3880_0000u if value is subnormal, 0 otherwise + Vector512 maskedExponentLowerBound = subnormalMask & Vector512.Create(ExponentLowerBound); + + // 0x3880_0000u if value is subnormal, 0x3800_0000u otherwise + Vector512 offsetMaskedExponentLowerBound = Vector512.Create(ExponentOffset) | maskedExponentLowerBound; + + // Match the position of the boundary of exponent bits and fraction bits with IEEE 754 Binary32(Single) + bitValueInProcess = Vector512.ShiftLeft(bitValueInProcess, 13); + + // Double the offsetMaskedExponentLowerBound if value is either Infinity or NaN + offsetMaskedExponentLowerBound = Vector512.ConditionalSelect(Vector512.Equals(infinityOrNaNMask, Vector512.Zero), + offsetMaskedExponentLowerBound, + Vector512.ShiftLeft(offsetMaskedExponentLowerBound, 1)); + + // Extract exponent bits and fraction bits of value + bitValueInProcess &= Vector512.Create(HalfToSingleBitsMask); + + // Adjust exponent to match the range of exponent + bitValueInProcess += offsetMaskedExponentLowerBound; + + // If value is subnormal, remove unnecessary 1 on top of fraction bits. + Vector512 absoluteValue = (bitValueInProcess.AsSingle() - maskedExponentLowerBound.AsSingle()).AsUInt32(); + + // Merge sign bit with rest + return (absoluteValue | sign).AsSingle(); } +#endif } - private static unsafe void InvokeSpanScalarIntoSpan( - ReadOnlySpan x, float y, Span destination) - where TBinaryOperator : IBinaryOperator + /// Computes the cosine similarity between the two specified non-empty, equal-length tensors of single-precision floating-point numbers. + /// Assumes arguments have already been validated to be non-empty and equal length. + private static float CosineSimilarityCore(ReadOnlySpan x, ReadOnlySpan y) { - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } - - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float dRef = ref MemoryMarshal.GetReference(destination); - int i = 0, oneVectorFromEnd; + // Compute the same as: + // TensorPrimitives.Dot(x, y) / (Math.Sqrt(TensorPrimitives.SumOfSquares(x)) * Math.Sqrt(TensorPrimitives.SumOfSquares(y))) + // but only looping over each span once. 
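// Illustrative only, not part of this change: a minimal scalar sketch of the quantity the
// vectorized paths below accumulate in a single pass. The helper name is hypothetical; the
// actual paths additionally use FusedMultiplyAdd and mask off the overlapping final vector.
static float CosineSimilarityScalarSketch(ReadOnlySpan<float> x, ReadOnlySpan<float> y)
{
    float dotProduct = 0f, xSumOfSquares = 0f, ySumOfSquares = 0f;

    for (int i = 0; i < x.Length; i++)
    {
        dotProduct += x[i] * y[i];
        xSumOfSquares += x[i] * x[i];
        ySumOfSquares += y[i] * y[i];
    }

    // Sum(X * Y) / (|X| * |Y|)
    return dotProduct / (MathF.Sqrt(xSumOfSquares) * MathF.Sqrt(ySumOfSquares));
}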
#if NET8_0_OR_GREATER - if (Vector512.IsHardwareAccelerated) + if (Vector512.IsHardwareAccelerated && x.Length >= Vector512.Count) { - oneVectorFromEnd = x.Length - Vector512.Count; - if (i <= oneVectorFromEnd) + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + + Vector512 dotProductVector = Vector512.Zero; + Vector512 xSumOfSquaresVector = Vector512.Zero; + Vector512 ySumOfSquaresVector = Vector512.Zero; + + // Process vectors, summing their dot products and squares, as long as there's a vector's worth remaining. + int oneVectorFromEnd = x.Length - Vector512.Count; + int i = 0; + do { - Vector512 yVec = Vector512.Create(y); + Vector512 xVec = Vector512.LoadUnsafe(ref xRef, (uint)i); + Vector512 yVec = Vector512.LoadUnsafe(ref yRef, (uint)i); - // Loop handling one vector at a time. - do - { - TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i), - yVec).StoreUnsafe(ref dRef, (uint)i); + dotProductVector = FusedMultiplyAdd(xVec, yVec, dotProductVector); + xSumOfSquaresVector = FusedMultiplyAdd(xVec, xVec, xSumOfSquaresVector); + ySumOfSquaresVector = FusedMultiplyAdd(yVec, yVec, ySumOfSquaresVector); - i += Vector512.Count; - } - while (i <= oneVectorFromEnd); + i += Vector512.Count; + } + while (i <= oneVectorFromEnd); - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector512.Count); - TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex), - yVec).StoreUnsafe(ref dRef, lastVectorIndex); - } + // Process the last vector in the span, masking off elements already processed. + if (i != x.Length) + { + Vector512 xVec = Vector512.LoadUnsafe(ref xRef, (uint)(x.Length - Vector512.Count)); + Vector512 yVec = Vector512.LoadUnsafe(ref yRef, (uint)(x.Length - Vector512.Count)); - return; + Vector512 remainderMask = CreateRemainderMaskSingleVector512(x.Length - i); + xVec &= remainderMask; + yVec &= remainderMask; + + dotProductVector = FusedMultiplyAdd(xVec, yVec, dotProductVector); + xSumOfSquaresVector = FusedMultiplyAdd(xVec, xVec, xSumOfSquaresVector); + ySumOfSquaresVector = FusedMultiplyAdd(yVec, yVec, ySumOfSquaresVector); } + + // Sum(X * Y) / (|X| * |Y|) + return + Vector512.Sum(dotProductVector) / + (MathF.Sqrt(Vector512.Sum(xSumOfSquaresVector)) * MathF.Sqrt(Vector512.Sum(ySumOfSquaresVector))); } #endif - if (Vector256.IsHardwareAccelerated) + if (Vector256.IsHardwareAccelerated && x.Length >= Vector256.Count) { - oneVectorFromEnd = x.Length - Vector256.Count; - if (i <= oneVectorFromEnd) + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + + Vector256 dotProductVector = Vector256.Zero; + Vector256 xSumOfSquaresVector = Vector256.Zero; + Vector256 ySumOfSquaresVector = Vector256.Zero; + + // Process vectors, summing their dot products and squares, as long as there's a vector's worth remaining. + int oneVectorFromEnd = x.Length - Vector256.Count; + int i = 0; + do { - Vector256 yVec = Vector256.Create(y); + Vector256 xVec = Vector256.LoadUnsafe(ref xRef, (uint)i); + Vector256 yVec = Vector256.LoadUnsafe(ref yRef, (uint)i); - // Loop handling one vector at a time. 
- do - { - TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i), - yVec).StoreUnsafe(ref dRef, (uint)i); + dotProductVector = FusedMultiplyAdd(xVec, yVec, dotProductVector); + xSumOfSquaresVector = FusedMultiplyAdd(xVec, xVec, xSumOfSquaresVector); + ySumOfSquaresVector = FusedMultiplyAdd(yVec, yVec, ySumOfSquaresVector); - i += Vector256.Count; - } - while (i <= oneVectorFromEnd); + i += Vector256.Count; + } + while (i <= oneVectorFromEnd); - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector256.Count); - TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex), - yVec).StoreUnsafe(ref dRef, lastVectorIndex); - } + // Process the last vector in the span, masking off elements already processed. + if (i != x.Length) + { + Vector256 xVec = Vector256.LoadUnsafe(ref xRef, (uint)(x.Length - Vector256.Count)); + Vector256 yVec = Vector256.LoadUnsafe(ref yRef, (uint)(x.Length - Vector256.Count)); - return; + Vector256 remainderMask = CreateRemainderMaskSingleVector256(x.Length - i); + xVec &= remainderMask; + yVec &= remainderMask; + + dotProductVector = FusedMultiplyAdd(xVec, yVec, dotProductVector); + xSumOfSquaresVector = FusedMultiplyAdd(xVec, xVec, xSumOfSquaresVector); + ySumOfSquaresVector = FusedMultiplyAdd(yVec, yVec, ySumOfSquaresVector); } + + // Sum(X * Y) / (|X| * |Y|) + return + Vector256.Sum(dotProductVector) / + (MathF.Sqrt(Vector256.Sum(xSumOfSquaresVector)) * MathF.Sqrt(Vector256.Sum(ySumOfSquaresVector))); } - if (Vector128.IsHardwareAccelerated) + if (Vector128.IsHardwareAccelerated && x.Length >= Vector128.Count) { - oneVectorFromEnd = x.Length - Vector128.Count; - if (i <= oneVectorFromEnd) + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + + Vector128 dotProductVector = Vector128.Zero; + Vector128 xSumOfSquaresVector = Vector128.Zero; + Vector128 ySumOfSquaresVector = Vector128.Zero; + + // Process vectors, summing their dot products and squares, as long as there's a vector's worth remaining. + int oneVectorFromEnd = x.Length - Vector128.Count; + int i = 0; + do { - Vector128 yVec = Vector128.Create(y); + Vector128 xVec = Vector128.LoadUnsafe(ref xRef, (uint)i); + Vector128 yVec = Vector128.LoadUnsafe(ref yRef, (uint)i); - // Loop handling one vector at a time. - do - { - TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i), - yVec).StoreUnsafe(ref dRef, (uint)i); + dotProductVector = FusedMultiplyAdd(xVec, yVec, dotProductVector); + xSumOfSquaresVector = FusedMultiplyAdd(xVec, xVec, xSumOfSquaresVector); + ySumOfSquaresVector = FusedMultiplyAdd(yVec, yVec, ySumOfSquaresVector); - i += Vector128.Count; - } - while (i <= oneVectorFromEnd); + i += Vector128.Count; + } + while (i <= oneVectorFromEnd); - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector128.Count); - TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex), - yVec).StoreUnsafe(ref dRef, lastVectorIndex); - } + // Process the last vector in the span, masking off elements already processed. 
+ if (i != x.Length) + { + Vector128 xVec = Vector128.LoadUnsafe(ref xRef, (uint)(x.Length - Vector128.Count)); + Vector128 yVec = Vector128.LoadUnsafe(ref yRef, (uint)(x.Length - Vector128.Count)); - return; + Vector128 remainderMask = CreateRemainderMaskSingleVector128(x.Length - i); + xVec &= remainderMask; + yVec &= remainderMask; + + dotProductVector = FusedMultiplyAdd(xVec, yVec, dotProductVector); + xSumOfSquaresVector = FusedMultiplyAdd(xVec, xVec, xSumOfSquaresVector); + ySumOfSquaresVector = FusedMultiplyAdd(yVec, yVec, ySumOfSquaresVector); } + + // Sum(X * Y) / (|X| * |Y|) + return + Vector128.Sum(dotProductVector) / + (MathF.Sqrt(Vector128.Sum(xSumOfSquaresVector)) * MathF.Sqrt(Vector128.Sum(ySumOfSquaresVector))); } - while (i < x.Length) + // Vectorization isn't supported or there are too few elements to vectorize. + // Use a scalar implementation. + float dotProduct = 0f, xSumOfSquares = 0f, ySumOfSquares = 0f; + for (int i = 0; i < x.Length; i++) { - Unsafe.Add(ref dRef, i) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, i), - y); - - i++; + dotProduct = MathF.FusedMultiplyAdd(x[i], y[i], dotProduct); + xSumOfSquares = MathF.FusedMultiplyAdd(x[i], x[i], xSumOfSquares); + ySumOfSquares = MathF.FusedMultiplyAdd(y[i], y[i], ySumOfSquares); } + + // Sum(X * Y) / (|X| * |Y|) + return + dotProduct / + (MathF.Sqrt(xSumOfSquares) * MathF.Sqrt(ySumOfSquares)); } - private static unsafe void InvokeSpanSpanSpanIntoSpan( - ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan z, Span destination) - where TTernaryOperator : ITernaryOperator + /// Performs an aggregation over all elements in to produce a single-precision floating-point value. + /// Specifies the transform operation that should be applied to each element loaded from . + /// + /// Specifies the aggregation binary operation that should be applied to multiple values to aggregate them into a single value. + /// The aggregation is applied after the transform is applied to each element. + /// + private static float Aggregate( + ReadOnlySpan x) + where TTransformOperator : struct, IUnaryOperator + where TAggregationOperator : struct, IAggregationOperator { - if (x.Length != y.Length || x.Length != z.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); - } - - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes ref float xRef = ref MemoryMarshal.GetReference(x); - ref float yRef = ref MemoryMarshal.GetReference(y); - ref float zRef = ref MemoryMarshal.GetReference(z); - ref float dRef = ref MemoryMarshal.GetReference(destination); - int i = 0, oneVectorFromEnd; + + nuint remainder = (uint)(x.Length); #if NET8_0_OR_GREATER if (Vector512.IsHardwareAccelerated) { - oneVectorFromEnd = x.Length - Vector512.Count; - if (i <= oneVectorFromEnd) - { - // Loop handling one vector at a time. - do - { - TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i), - Vector512.LoadUnsafe(ref yRef, (uint)i), - Vector512.LoadUnsafe(ref zRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); - - i += Vector512.Count; - } - while (i <= oneVectorFromEnd); + float result; - // Handle any remaining elements with a final vector. 
- if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector512.Count); - TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex), - Vector512.LoadUnsafe(ref yRef, lastVectorIndex), - Vector512.LoadUnsafe(ref zRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); - } + if (remainder >= (uint)(Vector512.Count)) + { + result = Vectorized512(ref xRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. - return; + result = Vectorized512Small(ref xRef, remainder); } + + return result; } #endif if (Vector256.IsHardwareAccelerated) { - oneVectorFromEnd = x.Length - Vector256.Count; - if (i <= oneVectorFromEnd) - { - // Loop handling one vector at a time. - do - { - TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i), - Vector256.LoadUnsafe(ref yRef, (uint)i), - Vector256.LoadUnsafe(ref zRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); - - i += Vector256.Count; - } - while (i <= oneVectorFromEnd); + float result; - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector256.Count); - TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex), - Vector256.LoadUnsafe(ref yRef, lastVectorIndex), - Vector256.LoadUnsafe(ref zRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); - } + if (remainder >= (uint)(Vector256.Count)) + { + result = Vectorized256(ref xRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. - return; + result = Vectorized256Small(ref xRef, remainder); } + + return result; } if (Vector128.IsHardwareAccelerated) { - oneVectorFromEnd = x.Length - Vector128.Count; - if (i <= oneVectorFromEnd) - { - // Loop handling one vector at a time. - do - { - TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i), - Vector128.LoadUnsafe(ref yRef, (uint)i), - Vector128.LoadUnsafe(ref zRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); - - i += Vector128.Count; - } - while (i <= oneVectorFromEnd); + float result; - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector128.Count); - TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex), - Vector128.LoadUnsafe(ref yRef, lastVectorIndex), - Vector128.LoadUnsafe(ref zRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); - } + if (remainder >= (uint)(Vector128.Count)) + { + result = Vectorized128(ref xRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. - return; + result = Vectorized128Small(ref xRef, remainder); } + + return result; } - while (i < x.Length) - { - Unsafe.Add(ref dRef, i) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, i), - Unsafe.Add(ref yRef, i), - Unsafe.Add(ref zRef, i)); + // This is the software fallback when no acceleration is available. + // It requires no branches to hit. 
- i++; - } - } + return SoftwareFallback(ref xRef, remainder); - private static unsafe void InvokeSpanSpanScalarIntoSpan( - ReadOnlySpan x, ReadOnlySpan y, float z, Span destination) - where TTernaryOperator : ITernaryOperator - { - if (x.Length != y.Length) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float SoftwareFallback(ref float xRef, nuint length) { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + float result = TAggregationOperator.IdentityValue; + + for (nuint i = 0; i < length; i++) + { + result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(Unsafe.Add(ref xRef, i))); + } + + return result; } - if (x.Length > destination.Length) + static float Vectorized128(ref float xRef, nuint remainder) { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float yRef = ref MemoryMarshal.GetReference(y); - ref float dRef = ref MemoryMarshal.GetReference(destination); - int i = 0, oneVectorFromEnd; + // Preload the beginning and end so that overlapping accesses don't negatively impact the data -#if NET8_0_OR_GREATER - if (Vector512.IsHardwareAccelerated) - { - oneVectorFromEnd = x.Length - Vector512.Count; - if (i <= oneVectorFromEnd) + Vector128 beg = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + Vector128 end = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))); + + nuint misalignment = 0; + + if (remainder > (uint)(Vector128.Count * 8)) { - Vector512 zVec = Vector512.Create(z); + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. - // Loop handling one vector at a time. - do + fixed (float* px = &xRef) { - TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i), - Vector512.LoadUnsafe(ref yRef, (uint)i), - zVec).StoreUnsafe(ref dRef, (uint)i); + float* xPtr = px; - i += Vector512.Count; - } - while (i <= oneVectorFromEnd); + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector512.Count); - TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex), - Vector512.LoadUnsafe(ref yRef, lastVectorIndex), - zVec).StoreUnsafe(ref dRef, lastVectorIndex); - } + bool canAlign = ((nuint)(xPtr) % sizeof(float)) == 0; - return; + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. 
+ + misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(xPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); + + xPtr += misalignment; + + Debug.Assert(((nuint)(xPtr) % (uint)(sizeof(Vector128))) == 0); + + remainder -= misalignment; + } + + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; + + // We only need to load, so there isn't a lot of benefit to doing non-temporal operations + + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0))); + vector2 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1))); + vector3 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2))); + vector4 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We load, process, and store the next four vectors + + vector1 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4))); + vector2 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5))); + vector3 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6))); + vector4 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + } } - } -#endif - if (Vector256.IsHardwareAccelerated) - { - oneVectorFromEnd = x.Length - Vector256.Count; - if (i <= oneVectorFromEnd) + // Store the first block. Handling this separately simplifies the latter code as we know + // they come after and so we can relegate it to full blocks or the trailing elements + + beg = Vector128.ConditionalSelect(CreateAlignmentMaskSingleVector128((int)(misalignment)), beg, Vector128.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + // Process the remaining [0, Count * 7] elements via a jump table + // + // We end up handling any trailing elements in case 0 and in the + // worst case end up just doing the identity operation here if there + // were no trailing elements. + + (nuint blocks, nuint trailing) = Math.DivRem(remainder, (nuint)(Vector128.Count)); + blocks -= (misalignment == 0) ? 1u : 0u; + remainder -= trailing; + + switch (blocks) { - Vector256 zVec = Vector256.Create(z); + case 7: + { + Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 6; + } - // Loop handling one vector at a time. 
- do + case 6: { - TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i), - Vector256.LoadUnsafe(ref yRef, (uint)i), - zVec).StoreUnsafe(ref dRef, (uint)i); + Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 5; + } - i += Vector256.Count; + case 5: + { + Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 4; } - while (i <= oneVectorFromEnd); - // Handle any remaining elements with a final vector. - if (i != x.Length) + case 4: { - uint lastVectorIndex = (uint)(x.Length - Vector256.Count); - TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex), - Vector256.LoadUnsafe(ref yRef, lastVectorIndex), - zVec).StoreUnsafe(ref dRef, lastVectorIndex); + Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 3; } - return; - } + case 3: + { + Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 2; + } + + case 2: + { + Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 1; + } + + case 1: + { + Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 1))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 0; + } + + case 0: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end = Vector128.ConditionalSelect(CreateRemainderMaskSingleVector128((int)(trailing)), end, Vector128.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, end); + break; + } + } + + return TAggregationOperator.Invoke(vresult); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float Vectorized128Small(ref float xRef, nuint remainder) + { + float result = TAggregationOperator.IdentityValue; + + switch (remainder) + { + case 3: + { + result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(Unsafe.Add(ref xRef, 2))); + goto case 2; + } + + case 2: + { + result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(Unsafe.Add(ref xRef, 1))); + goto case 1; + } + + case 1: + { + result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(xRef)); + goto case 0; + } + + case 0: + { + break; + } + } + + return result; + } + + static float Vectorized256(ref float xRef, nuint remainder) + { + Vector256 vresult = Vector256.Create(TAggregationOperator.IdentityValue); + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector256 beg = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); + Vector256 end = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count))); + + nuint misalignment = 0; + + if (remainder > (uint)(Vector256.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. 
+ + fixed (float* px = &xRef) + { + float* xPtr = px; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(xPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(xPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + + xPtr += misalignment; + + Debug.Assert(((nuint)(xPtr) % (uint)(sizeof(Vector256))) == 0); + + remainder -= misalignment; + } + + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; + + // We only need to load, so there isn't a lot of benefit to doing non-temporal operations + + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0))); + vector2 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1))); + vector3 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2))); + vector4 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We load, process, and store the next four vectors + + vector1 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4))); + vector2 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5))); + vector3 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6))); + vector4 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + } + } + + // Store the first block. Handling this separately simplifies the latter code as we know + // they come after and so we can relegate it to full blocks or the trailing elements + + beg = Vector256.ConditionalSelect(CreateAlignmentMaskSingleVector256((int)(misalignment)), beg, Vector256.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + // Process the remaining [0, Count * 7] elements via a jump table + // + // We end up handling any trailing elements in case 0 and in the + // worst case end up just doing the identity operation here if there + // were no trailing elements. + + (nuint blocks, nuint trailing) = Math.DivRem(remainder, (nuint)(Vector256.Count)); + blocks -= (misalignment == 0) ? 
1u : 0u; + remainder -= trailing; + + switch (blocks) + { + case 7: + { + Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 6; + } + + case 6: + { + Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 5; + } + + case 5: + { + Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 4; + } + + case 4: + { + Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 3; + } + + case 3: + { + Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 2; + } + + case 2: + { + Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 1; + } + + case 1: + { + Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 1))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 0; + } + + case 0: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end = Vector256.ConditionalSelect(CreateRemainderMaskSingleVector256((int)(trailing)), end, Vector256.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, end); + break; + } + } + + return TAggregationOperator.Invoke(vresult); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float Vectorized256Small(ref float xRef, nuint remainder) + { + float result = TAggregationOperator.IdentityValue; + + switch (remainder) + { + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + + Vector128 beg = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + Vector128 end = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))); + + end = Vector128.ConditionalSelect(CreateRemainderMaskSingleVector128((int)(remainder % (uint)(Vector128.Count))), end, Vector128.Create(TAggregationOperator.IdentityValue)); + + vresult = TAggregationOperator.Invoke(vresult, beg); + vresult = TAggregationOperator.Invoke(vresult, end); + + result = TAggregationOperator.Invoke(vresult); + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + + Vector128 beg = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + result = TAggregationOperator.Invoke(vresult); + break; + } + + case 3: + { + result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(Unsafe.Add(ref xRef, 2))); + goto case 2; + } + + case 2: + { + result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(Unsafe.Add(ref xRef, 1))); + goto case 1; + } + + case 1: + { + result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(xRef)); + 
goto case 0; + } + + case 0: + { + break; + } + } + + return result; + } + +#if NET8_0_OR_GREATER + static float Vectorized512(ref float xRef, nuint remainder) + { + Vector512 vresult = Vector512.Create(TAggregationOperator.IdentityValue); + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector512 beg = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef)); + Vector512 end = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count))); + + nuint misalignment = 0; + + if (remainder > (uint)(Vector512.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + { + float* xPtr = px; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(xPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(xPtr) % (uint)(sizeof(Vector512)))) / sizeof(float); + + xPtr += misalignment; + + Debug.Assert(((nuint)(xPtr) % (uint)(sizeof(Vector512))) == 0); + + remainder -= misalignment; + } + + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; + + // We only need to load, so there isn't a lot of benefit to doing non-temporal operations + + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0))); + vector2 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1))); + vector3 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2))); + vector4 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We load, process, and store the next four vectors + + vector1 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4))); + vector2 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5))); + vector3 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6))); + vector4 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + } + } + + // Store the first block. 
Handling this separately simplifies the latter code as we know + // they come after and so we can relegate it to full blocks or the trailing elements + + beg = Vector512.ConditionalSelect(CreateAlignmentMaskSingleVector512((int)(misalignment)), beg, Vector512.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + // Process the remaining [0, Count * 7] elements via a jump table + // + // We end up handling any trailing elements in case 0 and in the + // worst case end up just doing the identity operation here if there + // were no trailing elements. + + (nuint blocks, nuint trailing) = Math.DivRem(remainder, (nuint)(Vector512.Count)); + blocks -= (misalignment == 0) ? 1u : 0u; + remainder -= trailing; + + switch (blocks) + { + case 7: + { + Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 6; + } + + case 6: + { + Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 5; + } + + case 5: + { + Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 4; + } + + case 4: + { + Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 3; + } + + case 3: + { + Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 2; + } + + case 2: + { + Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 1; + } + + case 1: + { + Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 1))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 0; + } + + case 0: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end = Vector512.ConditionalSelect(CreateRemainderMaskSingleVector512((int)(trailing)), end, Vector512.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, end); + break; + } + } + + return TAggregationOperator.Invoke(vresult); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float Vectorized512Small(ref float xRef, nuint remainder) + { + float result = TAggregationOperator.IdentityValue; + + switch (remainder) + { + case 15: + case 14: + case 13: + case 12: + case 11: + case 10: + case 9: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + Vector256 vresult = Vector256.Create(TAggregationOperator.IdentityValue); + + Vector256 beg = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); + Vector256 end = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count))); + + end = Vector256.ConditionalSelect(CreateRemainderMaskSingleVector256((int)(remainder % (uint)(Vector256.Count))), end, Vector256.Create(TAggregationOperator.IdentityValue)); + + vresult = TAggregationOperator.Invoke(vresult, beg); + vresult = TAggregationOperator.Invoke(vresult, end); + 
+ result = TAggregationOperator.Invoke(vresult); + break; + } + + case 8: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + Vector256 vresult = Vector256.Create(TAggregationOperator.IdentityValue); + + Vector256 beg = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + result = TAggregationOperator.Invoke(vresult); + break; + } + + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + + Vector128 beg = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + Vector128 end = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))); + + end = Vector128.ConditionalSelect(CreateRemainderMaskSingleVector128((int)(remainder % (uint)(Vector128.Count))), end, Vector128.Create(TAggregationOperator.IdentityValue)); + + vresult = TAggregationOperator.Invoke(vresult, beg); + vresult = TAggregationOperator.Invoke(vresult, end); + + result = TAggregationOperator.Invoke(vresult); + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + + Vector128 beg = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + result = TAggregationOperator.Invoke(vresult); + break; + } + + case 3: + { + result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(Unsafe.Add(ref xRef, 2))); + goto case 2; + } + + case 2: + { + result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(Unsafe.Add(ref xRef, 1))); + goto case 1; + } + + case 1: + { + result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(xRef)); + goto case 0; + } + + case 0: + { + break; + } + } + + return result; + } +#endif + } + + /// Performs an aggregation over all pair-wise elements in and to produce a single-precision floating-point value. + /// Specifies the binary operation that should be applied to the pair-wise elements loaded from and . + /// + /// Specifies the aggregation binary operation that should be applied to multiple values to aggregate them into a single value. + /// The aggregation is applied to the results of the binary operations on the pair-wise values. + /// + private static float Aggregate( + ReadOnlySpan x, ReadOnlySpan y) + where TBinaryOperator : struct, IBinaryOperator + where TAggregationOperator : struct, IAggregationOperator + { + if (x.Length != y.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + + nuint remainder = (uint)(x.Length); + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated) + { + float result; + + if (remainder >= (uint)(Vector512.Count)) + { + result = Vectorized512(ref xRef, ref yRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. 
+ + result = Vectorized512Small(ref xRef, ref yRef, remainder); + } + + return result; + } +#endif + + if (Vector256.IsHardwareAccelerated) + { + float result; + + if (remainder >= (uint)(Vector256.Count)) + { + result = Vectorized256(ref xRef, ref yRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + result = Vectorized256Small(ref xRef, ref yRef, remainder); + } + + return result; + } + + if (Vector128.IsHardwareAccelerated) + { + float result; + + if (remainder >= (uint)(Vector128.Count)) + { + result = Vectorized128(ref xRef, ref yRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + result = Vectorized128Small(ref xRef, ref yRef, remainder); + } + + return result; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + return SoftwareFallback(ref xRef, ref yRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float SoftwareFallback(ref float xRef, ref float yRef, nuint length) + { + float result = TAggregationOperator.IdentityValue; + + for (nuint i = 0; i < length; i++) + { + result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(Unsafe.Add(ref xRef, i), + Unsafe.Add(ref yRef, i))); + } + + return result; + } + + static float Vectorized128(ref float xRef, ref float yRef, nuint remainder) + { + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + Vector128 end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count))); + + nuint misalignment = 0; + + if (remainder > (uint)(Vector128.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + { + float* xPtr = px; + float* yPtr = py; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(xPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. 
+ + misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(xPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + + Debug.Assert(((nuint)(xPtr) % (uint)(sizeof(Vector128))) == 0); + + remainder -= misalignment; + } + + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; + + // We only need to load, so there isn't a lot of benefit to doing non-temporal operations + + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 3))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 7))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + yPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + } + } + + // Store the first block. Handling this separately simplifies the latter code as we know + // they come after and so we can relegate it to full blocks or the trailing elements + + beg = Vector128.ConditionalSelect(CreateAlignmentMaskSingleVector128((int)(misalignment)), beg, Vector128.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + // Process the remaining [0, Count * 7] elements via a jump table + // + // We end up handling any trailing elements in case 0 and in the + // worst case end up just doing the identity operation here if there + // were no trailing elements. + + (nuint blocks, nuint trailing) = Math.DivRem(remainder, (nuint)(Vector128.Count)); + blocks -= (misalignment == 0) ? 
1u : 0u; + remainder -= trailing; + + switch (blocks) + { + case 7: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 7))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 6; + } + + case 6: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 6))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 5; + } + + case 5: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 5))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 4; + } + + case 4: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 4))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 3; + } + + case 3: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 3))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 2; + } + + case 2: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 2))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 1; + } + + case 1: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 1)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 1))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 0; + } + + case 0: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end = Vector128.ConditionalSelect(CreateRemainderMaskSingleVector128((int)(trailing)), end, Vector128.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, end); + break; + } + } + + return TAggregationOperator.Invoke(vresult); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float Vectorized128Small(ref float xRef, ref float yRef, nuint remainder) + { + float result = TAggregationOperator.IdentityValue; + + switch (remainder) + { + case 3: + { + result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2))); + goto case 2; + } + + case 2: + { + result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1))); + goto case 1; + } + + case 1: + { + result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(xRef, yRef)); + goto case 0; + } + + case 0: + { + break; + } + } + + return result; + } + + static float Vectorized256(ref float xRef, ref float yRef, nuint remainder) + { + Vector256 vresult = Vector256.Create(TAggregationOperator.IdentityValue); + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector256 beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef)); + Vector256 end = 
TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count))); + + nuint misalignment = 0; + + if (remainder > (uint)(Vector256.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + { + float* xPtr = px; + float* yPtr = py; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(xPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(xPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + + Debug.Assert(((nuint)(xPtr) % (uint)(sizeof(Vector256))) == 0); + + remainder -= misalignment; + } + + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; + + // We only need to load, so there isn't a lot of benefit to doing non-temporal operations + + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 3))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 7))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
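+                        // (Both adjusted pointers are sources here; e.g. with Vector256<float>.Count == 8
+                        // on AVX2, each pass of this loop consumes 8 vectors * 8 floats = 64 elements,
+                        // so xPtr and yPtr advance, and remainder shrinks, by exactly 64.)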
+ + xPtr += (uint)(Vector256.Count * 8); + yPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + } + } + + // Store the first block. Handling this separately simplifies the latter code as we know + // they come after and so we can relegate it to full blocks or the trailing elements + + beg = Vector256.ConditionalSelect(CreateAlignmentMaskSingleVector256((int)(misalignment)), beg, Vector256.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + // Process the remaining [0, Count * 7] elements via a jump table + // + // We end up handling any trailing elements in case 0 and in the + // worst case end up just doing the identity operation here if there + // were no trailing elements. + + (nuint blocks, nuint trailing) = Math.DivRem(remainder, (nuint)(Vector256.Count)); + blocks -= (misalignment == 0) ? 1u : 0u; + remainder -= trailing; + + switch (blocks) + { + case 7: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 7))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 6; + } + + case 6: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 6))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 5; + } + + case 5: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 5))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 4; + } + + case 4: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 4))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 3; + } + + case 3: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 3))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 2; + } + + case 2: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 2))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 1; + } + + case 1: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 1)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 1))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 0; + } + + case 0: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end = Vector256.ConditionalSelect(CreateRemainderMaskSingleVector256((int)(trailing)), end, Vector256.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, end); + break; + } + } + + return TAggregationOperator.Invoke(vresult); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float Vectorized256Small(ref float 
xRef, ref float yRef, nuint remainder) + { + float result = TAggregationOperator.IdentityValue; + + switch (remainder) + { + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + Vector128 end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count))); + + end = Vector128.ConditionalSelect(CreateRemainderMaskSingleVector128((int)(remainder % (uint)(Vector128.Count))), end, Vector128.Create(TAggregationOperator.IdentityValue)); + + vresult = TAggregationOperator.Invoke(vresult, beg); + vresult = TAggregationOperator.Invoke(vresult, end); + + result = TAggregationOperator.Invoke(vresult); + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + result = TAggregationOperator.Invoke(vresult); + break; + } + + case 3: + { + result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2))); + goto case 2; + } + + case 2: + { + result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1))); + goto case 1; + } + + case 1: + { + result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(xRef, yRef)); + goto case 0; + } + + case 0: + { + break; + } + } + + return result; + } + +#if NET8_0_OR_GREATER + static float Vectorized512(ref float xRef, ref float yRef, nuint remainder) + { + Vector512 vresult = Vector512.Create(TAggregationOperator.IdentityValue); + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector512 beg = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef), + Vector512.LoadUnsafe(ref yRef)); + Vector512 end = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count))); + + nuint misalignment = 0; + + if (remainder > (uint)(Vector512.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + { + float* xPtr = px; + float* yPtr = py; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(xPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. 
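+                        // (In this aggregate variant there is no destination; it is xPtr, the first
+                        // source, that gets aligned. Worked example: sizeof(Vector512<float>) is 64
+                        // bytes, so an xPtr with (address % 64) == 24 gives (64 - 24) / 4 = 10 elements
+                        // of misalignment; an already-aligned xPtr yields a full vector's worth, which
+                        // is fine because `beg` has already captured that first block.)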
+ + misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(xPtr) % (uint)(sizeof(Vector512)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + + Debug.Assert(((nuint)(xPtr) % (uint)(sizeof(Vector512))) == 0); + + remainder -= misalignment; + } + + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; + + // We only need to load, so there isn't a lot of benefit to doing non-temporal operations + + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 3))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 7))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + yPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + } + } + + // Store the first block. Handling this separately simplifies the latter code as we know + // they come after and so we can relegate it to full blocks or the trailing elements + + beg = Vector512.ConditionalSelect(CreateAlignmentMaskSingleVector512((int)(misalignment)), beg, Vector512.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + // Process the remaining [0, Count * 7] elements via a jump table + // + // We end up handling any trailing elements in case 0 and in the + // worst case end up just doing the identity operation here if there + // were no trailing elements. + + (nuint blocks, nuint trailing) = Math.DivRem(remainder, (nuint)(Vector512.Count)); + blocks -= (misalignment == 0) ? 
1u : 0u; + remainder -= trailing; + + switch (blocks) + { + case 7: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 7))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 6; + } + + case 6: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 6))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 5; + } + + case 5: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 5))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 4; + } + + case 4: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 4))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 3; + } + + case 3: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 3))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 2; + } + + case 2: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 2))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 1; + } + + case 1: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 1)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 1))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 0; + } + + case 0: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end = Vector512.ConditionalSelect(CreateRemainderMaskSingleVector512((int)(trailing)), end, Vector512.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, end); + break; + } + } + + return TAggregationOperator.Invoke(vresult); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float Vectorized512Small(ref float xRef, ref float yRef, nuint remainder) + { + float result = TAggregationOperator.IdentityValue; + + switch (remainder) + { + case 15: + case 14: + case 13: + case 12: + case 11: + case 10: + case 9: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + Vector256 vresult = Vector256.Create(TAggregationOperator.IdentityValue); + + Vector256 beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef)); + Vector256 end = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count))); + + end = Vector256.ConditionalSelect(CreateRemainderMaskSingleVector256((int)(remainder % (uint)(Vector256.Count))), end, Vector256.Create(TAggregationOperator.IdentityValue)); + + vresult = TAggregationOperator.Invoke(vresult, beg); + vresult = TAggregationOperator.Invoke(vresult, end); + + result = TAggregationOperator.Invoke(vresult); + break; + } + + case 8: + { + 
Debug.Assert(Vector256.IsHardwareAccelerated);
+                    Vector256<float> vresult = Vector256.Create(TAggregationOperator.IdentityValue);
+
+                    Vector256<float> beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef),
+                                                                  Vector256.LoadUnsafe(ref yRef));
+                    vresult = TAggregationOperator.Invoke(vresult, beg);
+
+                    result = TAggregationOperator.Invoke(vresult);
+                    break;
+                }
+
+                case 7:
+                case 6:
+                case 5:
+                {
+                    Debug.Assert(Vector128.IsHardwareAccelerated);
+                    Vector128<float> vresult = Vector128.Create(TAggregationOperator.IdentityValue);
+
+                    Vector128<float> beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef),
+                                                                  Vector128.LoadUnsafe(ref yRef));
+                    Vector128<float> end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128<float>.Count)),
+                                                                  Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128<float>.Count)));
+
+                    end = Vector128.ConditionalSelect(CreateRemainderMaskSingleVector128((int)(remainder % (uint)(Vector128<float>.Count))), end, Vector128.Create(TAggregationOperator.IdentityValue));
+
+                    vresult = TAggregationOperator.Invoke(vresult, beg);
+                    vresult = TAggregationOperator.Invoke(vresult, end);
+
+                    result = TAggregationOperator.Invoke(vresult);
+                    break;
+                }
+
+                case 4:
+                {
+                    Debug.Assert(Vector128.IsHardwareAccelerated);
+                    Vector128<float> vresult = Vector128.Create(TAggregationOperator.IdentityValue);
+
+                    Vector128<float> beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef),
+                                                                  Vector128.LoadUnsafe(ref yRef));
+                    vresult = TAggregationOperator.Invoke(vresult, beg);
+
+                    result = TAggregationOperator.Invoke(vresult);
+                    break;
+                }
+
+                case 3:
+                {
+                    result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 2),
+                                                                                        Unsafe.Add(ref yRef, 2)));
+                    goto case 2;
+                }
+
+                case 2:
+                {
+                    result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 1),
+                                                                                        Unsafe.Add(ref yRef, 1)));
+                    goto case 1;
+                }
+
+                case 1:
+                {
+                    result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(xRef, yRef));
+                    goto case 0;
+                }
+
+                case 0:
+                {
+                    break;
+                }
+            }
+
+            return result;
+        }
+#endif
+    }
+
+        /// <summary>
+        /// This is the same as <see cref="Aggregate{TTransformOperator, TAggregationOperator}(ReadOnlySpan{float})"/>
+        /// with an identity transform, except it early exits on NaN.
+        /// </summary>
+        private static float MinMaxCore<TMinMaxOperator>(ReadOnlySpan<float> x)
+            where TMinMaxOperator : struct, IAggregationOperator
+        {
+            if (x.IsEmpty)
+            {
+                ThrowHelper.ThrowArgument_SpansMustBeNonEmpty();
+            }
+
+            // This matches the IEEE 754:2019 `maximum`/`minimum` functions.
+            // It propagates NaN inputs back to the caller and otherwise
+            // returns the greater or lesser of the inputs, as appropriate.
+            // It treats +0 as greater than -0 as per the specification.
+
+#if NET8_0_OR_GREATER
+            if (Vector512.IsHardwareAccelerated && x.Length >= Vector512<float>.Count)
+            {
+                ref float xRef = ref MemoryMarshal.GetReference(x);
+
+                // Load the first vector as the initial set of results, and bail immediately
+                // to scalar handling if it contains any NaNs (which don't compare equally to themselves).
+                Vector512<float> result = Vector512.LoadUnsafe(ref xRef, 0), current;
+                if (!Vector512.EqualsAll(result, result))
+                {
+                    return GetFirstNaN(result);
+                }
+
+                int oneVectorFromEnd = x.Length - Vector512<float>.Count;
+                int i = Vector512<float>.Count;
+
+                // Aggregate additional vectors into the result as long as there's at least one full vector left to process.
+                while (i <= oneVectorFromEnd)
+                {
+                    // Load the next vector, and early exit on NaN.
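+                    // (NaN is the only IEEE 754 value that compares unequal to itself, so
+                    // EqualsAll(current, current) is false exactly when some lane holds a NaN,
+                    // and GetFirstNaN then extracts the first such lane.)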
+ current = Vector512.LoadUnsafe(ref xRef, (uint)i); + if (!Vector512.EqualsAll(current, current)) + { + return GetFirstNaN(current); + } + + result = TMinMaxOperator.Invoke(result, current); + i += Vector512.Count; + } + + // If any elements remain, handle them in one final vector. + if (i != x.Length) + { + current = Vector512.LoadUnsafe(ref xRef, (uint)(x.Length - Vector512.Count)); + if (!Vector512.EqualsAll(current, current)) + { + return GetFirstNaN(current); + } + + result = Vector512.ConditionalSelect( + Vector512.Equals(CreateRemainderMaskSingleVector512(x.Length - i), Vector512.Zero), + result, + TMinMaxOperator.Invoke(result, current)); + } + + // Aggregate the lanes in the vector to create the final scalar result. + return TMinMaxOperator.Invoke(result); + } +#endif + + if (Vector256.IsHardwareAccelerated && x.Length >= Vector256.Count) + { + ref float xRef = ref MemoryMarshal.GetReference(x); + + // Load the first vector as the initial set of results, and bail immediately + // to scalar handling if it contains any NaNs (which don't compare equally to themselves). + Vector256 result = Vector256.LoadUnsafe(ref xRef, 0), current; + if (!Vector256.EqualsAll(result, result)) + { + return GetFirstNaN(result); + } + + int oneVectorFromEnd = x.Length - Vector256.Count; + int i = Vector256.Count; + + // Aggregate additional vectors into the result as long as there's at least one full vector left to process. + while (i <= oneVectorFromEnd) + { + // Load the next vector, and early exit on NaN. + current = Vector256.LoadUnsafe(ref xRef, (uint)i); + if (!Vector256.EqualsAll(current, current)) + { + return GetFirstNaN(current); + } + + result = TMinMaxOperator.Invoke(result, current); + i += Vector256.Count; + } + + // If any elements remain, handle them in one final vector. + if (i != x.Length) + { + current = Vector256.LoadUnsafe(ref xRef, (uint)(x.Length - Vector256.Count)); + if (!Vector256.EqualsAll(current, current)) + { + return GetFirstNaN(current); + } + + result = Vector256.ConditionalSelect( + Vector256.Equals(CreateRemainderMaskSingleVector256(x.Length - i), Vector256.Zero), + result, + TMinMaxOperator.Invoke(result, current)); + } + + // Aggregate the lanes in the vector to create the final scalar result. + return TMinMaxOperator.Invoke(result); + } + + if (Vector128.IsHardwareAccelerated && x.Length >= Vector128.Count) + { + ref float xRef = ref MemoryMarshal.GetReference(x); + + // Load the first vector as the initial set of results, and bail immediately + // to scalar handling if it contains any NaNs (which don't compare equally to themselves). + Vector128 result = Vector128.LoadUnsafe(ref xRef, 0), current; + if (!Vector128.EqualsAll(result, result)) + { + return GetFirstNaN(result); + } + + int oneVectorFromEnd = x.Length - Vector128.Count; + int i = Vector128.Count; + + // Aggregate additional vectors into the result as long as there's at least one full vector left to process. + while (i <= oneVectorFromEnd) + { + // Load the next vector, and early exit on NaN. + current = Vector128.LoadUnsafe(ref xRef, (uint)i); + if (!Vector128.EqualsAll(current, current)) + { + return GetFirstNaN(current); + } + + result = TMinMaxOperator.Invoke(result, current); + i += Vector128.Count; + } + + // If any elements remain, handle them in one final vector. 
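+            // For example, with x.Length == 10 and Vector128<float>.Count == 4 the loop above
+            // stops at i == 8; the final vector is loaded at offset x.Length - Count == 6 and
+            // overlaps the already-processed elements 6 and 7, so the remainder mask folds in
+            // only the lanes for the x.Length - i == 2 new elements (8 and 9).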
+ if (i != x.Length) + { + current = Vector128.LoadUnsafe(ref xRef, (uint)(x.Length - Vector128.Count)); + if (!Vector128.EqualsAll(current, current)) + { + return GetFirstNaN(current); + } + + result = Vector128.ConditionalSelect( + Vector128.Equals(CreateRemainderMaskSingleVector128(x.Length - i), Vector128.Zero), + result, + TMinMaxOperator.Invoke(result, current)); + } + + // Aggregate the lanes in the vector to create the final scalar result. + return TMinMaxOperator.Invoke(result); + } + + // Scalar path used when either vectorization is not supported or the input is too small to vectorize. + { + float result = x[0]; + if (float.IsNaN(result)) + { + return result; + } + + for (int i = 1; i < x.Length; i++) + { + float current = x[i]; + if (float.IsNaN(current)) + { + return current; + } + + result = TMinMaxOperator.Invoke(result, current); + } + + return result; + } + } + + private static int IndexOfMinMaxCore(ReadOnlySpan x) where TIndexOfMinMax : struct, IIndexOfOperator + { + // This matches the IEEE 754:2019 `maximum`/`minimum` functions. + // It propagates NaN inputs back to the caller and + // otherwise returns the index of the greater of the inputs. + // It treats +0 as greater than -0 as per the specification. + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated && x.Length >= Vector512.Count) + { + ref float xRef = ref MemoryMarshal.GetReference(x); + Vector512 resultIndex = Vector512.Create(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + Vector512 curIndex = resultIndex; + Vector512 increment = Vector512.Create(Vector512.Count); + + // Load the first vector as the initial set of results, and bail immediately + // to scalar handling if it contains any NaNs (which don't compare equally to themselves). + Vector512 result = Vector512.LoadUnsafe(ref xRef); + Vector512 current; + + Vector512 nanMask = ~Vector512.Equals(result, result); + if (nanMask != Vector512.Zero) + { + return IndexOfFirstMatch(nanMask); + } + + int oneVectorFromEnd = x.Length - Vector512.Count; + int i = Vector512.Count; + + // Aggregate additional vectors into the result as long as there's at least one full vector left to process. + while (i <= oneVectorFromEnd) + { + // Load the next vector, and early exit on NaN. + current = Vector512.LoadUnsafe(ref xRef, (uint)i); + curIndex += increment; + + nanMask = ~Vector512.Equals(current, current); + if (nanMask != Vector512.Zero) + { + return i + IndexOfFirstMatch(nanMask); + } + + TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, curIndex); + + i += Vector512.Count; + } + + // If any elements remain, handle them in one final vector. + if (i != x.Length) + { + current = Vector512.LoadUnsafe(ref xRef, (uint)(x.Length - Vector512.Count)); + curIndex += Vector512.Create(x.Length - i); + + nanMask = ~Vector512.Equals(current, current); + if (nanMask != Vector512.Zero) + { + return curIndex[IndexOfFirstMatch(nanMask)]; + } + + TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, curIndex); + } + + // Aggregate the lanes in the vector to create the final scalar result. 
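+            // (The reduction picks the extreme lane of `result` and returns the corresponding
+            // lane of `resultIndex`; for first-occurrence semantics, ties between equal lane
+            // values are presumably resolved toward the smaller index by the operator.)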
+ return TIndexOfMinMax.Invoke(result, resultIndex); + } +#endif + + if (Vector256.IsHardwareAccelerated && x.Length >= Vector256.Count) + { + ref float xRef = ref MemoryMarshal.GetReference(x); + Vector256 resultIndex = Vector256.Create(0, 1, 2, 3, 4, 5, 6, 7); + Vector256 curIndex = resultIndex; + Vector256 increment = Vector256.Create(Vector256.Count); + + // Load the first vector as the initial set of results, and bail immediately + // to scalar handling if it contains any NaNs (which don't compare equally to themselves). + Vector256 result = Vector256.LoadUnsafe(ref xRef); + Vector256 current; + + Vector256 nanMask = ~Vector256.Equals(result, result); + if (nanMask != Vector256.Zero) + { + return IndexOfFirstMatch(nanMask); + } + + int oneVectorFromEnd = x.Length - Vector256.Count; + int i = Vector256.Count; + + // Aggregate additional vectors into the result as long as there's at least one full vector left to process. + while (i <= oneVectorFromEnd) + { + // Load the next vector, and early exit on NaN. + current = Vector256.LoadUnsafe(ref xRef, (uint)i); + curIndex += increment; + + nanMask = ~Vector256.Equals(current, current); + if (nanMask != Vector256.Zero) + { + return i + IndexOfFirstMatch(nanMask); + } + + TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, curIndex); + + i += Vector256.Count; + } + + // If any elements remain, handle them in one final vector. + if (i != x.Length) + { + current = Vector256.LoadUnsafe(ref xRef, (uint)(x.Length - Vector256.Count)); + curIndex += Vector256.Create(x.Length - i); + + nanMask = ~Vector256.Equals(current, current); + if (nanMask != Vector256.Zero) + { + return curIndex[IndexOfFirstMatch(nanMask)]; + } + + TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, curIndex); + } + + // Aggregate the lanes in the vector to create the final scalar result. + return TIndexOfMinMax.Invoke(result, resultIndex); + } + + if (Vector128.IsHardwareAccelerated && x.Length >= Vector128.Count) + { + ref float xRef = ref MemoryMarshal.GetReference(x); + Vector128 resultIndex = Vector128.Create(0, 1, 2, 3); + Vector128 curIndex = resultIndex; + Vector128 increment = Vector128.Create(Vector128.Count); + + // Load the first vector as the initial set of results, and bail immediately + // to scalar handling if it contains any NaNs (which don't compare equally to themselves). + Vector128 result = Vector128.LoadUnsafe(ref xRef); + Vector128 current; + + Vector128 nanMask = ~Vector128.Equals(result, result); + if (nanMask != Vector128.Zero) + { + return IndexOfFirstMatch(nanMask); + } + + int oneVectorFromEnd = x.Length - Vector128.Count; + int i = Vector128.Count; + + // Aggregate additional vectors into the result as long as there's at least one full vector left to process. + while (i <= oneVectorFromEnd) + { + // Load the next vector, and early exit on NaN. + current = Vector128.LoadUnsafe(ref xRef, (uint)i); + curIndex += increment; + + nanMask = ~Vector128.Equals(current, current); + if (nanMask != Vector128.Zero) + { + return i + IndexOfFirstMatch(nanMask); + } + + TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, curIndex); + + i += Vector128.Count; + } + + // If any elements remain, handle them in one final vector. 
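+            // For example, with x.Length == 10 and Vector128<float>.Count == 4 the loop above
+            // exits at i == 8 with curIndex == {4, 5, 6, 7}; adding x.Length - i == 2 gives
+            // {6, 7, 8, 9}, which matches the overlapping reload at offset x.Length - Count == 6.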
+                if (i != x.Length)
+                {
+                    curIndex += Vector128.Create(x.Length - i);
+
+                    current = Vector128.LoadUnsafe(ref xRef, (uint)(x.Length - Vector128<float>.Count));
+
+                    nanMask = ~Vector128.Equals(current, current);
+                    if (nanMask != Vector128<float>.Zero)
+                    {
+                        return curIndex[IndexOfFirstMatch(nanMask)];
+                    }
+
+                    TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, curIndex);
+                }
+
+                // Aggregate the lanes in the vector to create the final scalar result.
+                return TIndexOfMinMax.Invoke(result, resultIndex);
+            }
+
+            // Scalar path used when either vectorization is not supported or the input is too small to vectorize.
+            float curResult = x[0];
+            int curIn = 0;
+            if (float.IsNaN(curResult))
+            {
+                return curIn;
+            }
+
+            for (int i = 1; i < x.Length; i++)
+            {
+                float current = x[i];
+                if (float.IsNaN(current))
+                {
+                    return i;
+                }
+
+                curIn = TIndexOfMinMax.Invoke(ref curResult, current, curIn, i);
+            }
+
+            return curIn;
+        }
+
+        private static int IndexOfFirstMatch(Vector128<float> mask)
+        {
+            return BitOperations.TrailingZeroCount(mask.ExtractMostSignificantBits());
+        }
+
+        private static int IndexOfFirstMatch(Vector256<float> mask)
+        {
+            return BitOperations.TrailingZeroCount(mask.ExtractMostSignificantBits());
+        }
+
+#if NET8_0_OR_GREATER
+        private static int IndexOfFirstMatch(Vector512<float> mask)
+        {
+            return BitOperations.TrailingZeroCount(mask.ExtractMostSignificantBits());
+        }
+#endif
+
+        /// <summary>Performs an element-wise operation on <paramref name="x"/> and writes the results to <paramref name="destination"/>.</summary>
+        /// <typeparam name="TUnaryOperator">Specifies the operation to perform on each element loaded from <paramref name="x"/>.</typeparam>
+        private static void InvokeSpanIntoSpan<TUnaryOperator>(
+            ReadOnlySpan<float> x, Span<float> destination)
+            where TUnaryOperator : struct, IUnaryOperator
+        {
+            if (x.Length > destination.Length)
+            {
+                ThrowHelper.ThrowArgument_DestinationTooShort();
+            }
+
+            ValidateInputOutputSpanNonOverlapping(x, destination);
+
+            // Since every branch has a cost and since that cost is
+            // essentially lost for larger inputs, we do branches
+            // in a way that allows us to have the minimum possible
+            // for small sizes
+
+            ref float xRef = ref MemoryMarshal.GetReference(x);
+            ref float dRef = ref MemoryMarshal.GetReference(destination);
+
+            nuint remainder = (uint)(x.Length);
+
+#if NET8_0_OR_GREATER
+            if (Vector512.IsHardwareAccelerated)
+            {
+                if (remainder >= (uint)(Vector512<float>.Count))
+                {
+                    Vectorized512(ref xRef, ref dRef, remainder);
+                }
+                else
+                {
+                    // We have less than a vector and so we can only handle this as scalar. To do this
+                    // efficiently, we simply have a small jump table and fallthrough. So we get a simple
+                    // length check, single jump, and then linear execution.
+
+                    Vectorized512Small(ref xRef, ref dRef, remainder);
+                }
+
+                return;
+            }
+#endif
+
+            if (Vector256.IsHardwareAccelerated)
+            {
+                if (remainder >= (uint)(Vector256<float>.Count))
+                {
+                    Vectorized256(ref xRef, ref dRef, remainder);
+                }
+                else
+                {
+                    // We have less than a vector and so we can only handle this as scalar. To do this
+                    // efficiently, we simply have a small jump table and fallthrough. So we get a simple
+                    // length check, single jump, and then linear execution.
+
+                    Vectorized256Small(ref xRef, ref dRef, remainder);
+                }
+
+                return;
+            }
+
+            if (Vector128.IsHardwareAccelerated)
+            {
+                if (remainder >= (uint)(Vector128<float>.Count))
+                {
+                    Vectorized128(ref xRef, ref dRef, remainder);
+                }
+                else
+                {
+                    // We have less than a vector and so we can only handle this as scalar. To do this
+                    // efficiently, we simply have a small jump table and fallthrough. So we get a simple
+                    // length check, single jump, and then linear execution.
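+                    // (On this path remainder is in [0, Vector128<float>.Count), i.e. at most
+                    // three elements.)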
+ + Vectorized128Small(ref xRef, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, ref float dRef, nuint length) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, i) = TUnaryOperator.Invoke(Unsafe.Add(ref xRef, i)); + } + } + + static void Vectorized128(ref float xRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector128 beg = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + Vector128 end = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))); + + if (remainder > (uint)(Vector128.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(dPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); + + xPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector128))) == 0); + + remainder -= misalignment; + } + + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. 
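+                    // (Non-temporal stores are a hint to bypass the cache hierarchy, so they only
+                    // pay off once the working set is too large to remain cache-resident; the
+                    // NonTemporalByteThreshold constant, defined elsewhere in this file, encodes
+                    // that cutoff, and StoreAlignedNonTemporal additionally requires the aligned
+                    // dPtr established above.)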
+ + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0))); + vector2 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1))); + vector3 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2))); + vector4 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4))); + vector2 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5))); + vector3 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6))); + vector4 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0))); + vector2 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1))); + vector3 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2))); + vector4 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector128.Count * 0)); + vector2.Store(dPtr + (uint)(Vector128.Count * 1)); + vector3.Store(dPtr + (uint)(Vector128.Count * 2)); + vector4.Store(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4))); + vector2 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5))); + vector3 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6))); + vector4 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector128.Count * 4)); + vector2.Store(dPtr + (uint)(Vector128.Count * 5)); + vector3.Store(dPtr + (uint)(Vector128.Count * 6)); + vector4.Store(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); + + switch (remainder / (uint)(Vector128.Count)) + { + case 8: + { + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); + goto case 7; + } + + case 7: + { + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); + goto case 6; + } + + case 6: + { + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); + goto case 5; + } + + case 5: + { + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); + goto case 4; + } + + case 4: + { + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); + goto case 3; + } + + case 3: + { + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); + goto case 2; + } + + case 2: + { + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized128Small(ref float xRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 3: + { + Unsafe.Add(ref dRef, 2) = TUnaryOperator.Invoke(Unsafe.Add(ref xRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TUnaryOperator.Invoke(Unsafe.Add(ref xRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TUnaryOperator.Invoke(xRef); + goto case 0; + } + + case 0: + { + break; + } + } + } + + static void Vectorized256(ref float xRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector256 beg = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); + Vector256 end = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, 
remainder - (uint)(Vector256.Count))); + + if (remainder > (uint)(Vector256.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(dPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + + xPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector256))) == 0); + + remainder -= misalignment; + } + + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0))); + vector2 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1))); + vector3 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2))); + vector4 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4))); + vector2 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5))); + vector3 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6))); + vector4 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0))); + vector2 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1))); + vector3 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2))); + vector4 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 0)); + vector2.Store(dPtr + (uint)(Vector256.Count * 1)); + vector3.Store(dPtr + (uint)(Vector256.Count * 2)); + vector4.Store(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4))); + vector2 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5))); + vector3 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6))); + vector4 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 4)); + vector2.Store(dPtr + (uint)(Vector256.Count * 5)); + vector3.Store(dPtr + (uint)(Vector256.Count * 6)); + vector4.Store(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
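+            // Rounding remainder up to a multiple of Count makes the switch below enter at the
+            // correct case: e.g. with Vector256<float>.Count == 8, remainder == 53 becomes
+            // (53 + 7) & ~7 == 56, so dispatch enters at case 56 / 8 == 7 and falls through.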
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); + + switch (remainder / (uint)(Vector256.Count)) + { + case 8: + { + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); + goto case 7; + } + + case 7: + { + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); + goto case 6; + } + + case 6: + { + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); + goto case 5; + } + + case 5: + { + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); + goto case 4; + } + + case 4: + { + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); + goto case 3; + } + + case 3: + { + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); + goto case 2; + } + + case 2: + { + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized256Small(ref float xRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + Vector128 end = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TUnaryOperator.Invoke(Unsafe.Add(ref xRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TUnaryOperator.Invoke(Unsafe.Add(ref xRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TUnaryOperator.Invoke(xRef); + goto case 0; + } + + case 0: + { + break; + } + } + } + +#if NET8_0_OR_GREATER + static void Vectorized512(ref float xRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector512 beg = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef)); + Vector512 end = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - 
(uint)(Vector512.Count))); + + if (remainder > (uint)(Vector512.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(dPtr) % (uint)(sizeof(Vector512)))) / sizeof(float); + + xPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector512))) == 0); + + remainder -= misalignment; + } + + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0))); + vector2 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1))); + vector3 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2))); + vector4 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4))); + vector2 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5))); + vector3 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6))); + vector4 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0))); + vector2 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1))); + vector3 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2))); + vector4 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector512.Count * 0)); + vector2.Store(dPtr + (uint)(Vector512.Count * 1)); + vector3.Store(dPtr + (uint)(Vector512.Count * 2)); + vector4.Store(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4))); + vector2 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5))); + vector3 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6))); + vector4 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector512.Count * 4)); + vector2.Store(dPtr + (uint)(Vector512.Count * 5)); + vector3.Store(dPtr + (uint)(Vector512.Count * 6)); + vector4.Store(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector512.Count - 1)) & (nuint)(-Vector512.Count); + + switch (remainder / (uint)(Vector512.Count)) + { + case 8: + { + Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 8)); + goto case 7; + } + + case 7: + { + Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 7)); + goto case 6; + } + + case 6: + { + Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 6)); + goto case 5; + } + + case 5: + { + Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 5)); + goto case 4; + } + + case 4: + { + Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); + goto case 3; + } + + case 3: + { + Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); + goto case 2; + } + + case 2: + { + Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized512Small(ref float xRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 15: + case 14: + case 13: + case 12: + case 11: + case 10: + case 9: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); + Vector256 end = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count)); + + break; + } + + case 8: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + Vector128 end = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TUnaryOperator.Invoke(Unsafe.Add(ref xRef, 2)); + goto case 2; + } + + case 2: + 
+                    {
+                        Unsafe.Add(ref dRef, 1) = TUnaryOperator.Invoke(Unsafe.Add(ref xRef, 1));
+                        goto case 1;
+                    }
+
+                    case 1:
+                    {
+                        dRef = TUnaryOperator.Invoke(xRef);
+                        goto case 0;
+                    }
+
+                    case 0:
+                    {
+                        break;
+                    }
+                }
+            }
+#endif
+        }
+
+        /// <summary>
+        /// Performs an element-wise operation on <paramref name="x"/> and <paramref name="y"/>,
+        /// and writes the results to <paramref name="destination"/>.
+        /// </summary>
+        /// <typeparam name="TBinaryOperator">
+        /// Specifies the operation to perform on the pair-wise elements loaded from <paramref name="x"/> and <paramref name="y"/>.
+        /// </typeparam>
+        private static void InvokeSpanSpanIntoSpan<TBinaryOperator>(
+            ReadOnlySpan<float> x, ReadOnlySpan<float> y, Span<float> destination)
+            where TBinaryOperator : struct, IBinaryOperator
+        {
+            if (x.Length != y.Length)
+            {
+                ThrowHelper.ThrowArgument_SpansMustHaveSameLength();
+            }
+
+            if (x.Length > destination.Length)
+            {
+                ThrowHelper.ThrowArgument_DestinationTooShort();
+            }
+
+            ValidateInputOutputSpanNonOverlapping(x, destination);
+            ValidateInputOutputSpanNonOverlapping(y, destination);
+
+            // Since every branch has a cost, and that cost is essentially wasted on larger
+            // inputs, we order the checks so that small sizes take as few branches as possible.
+
+            ref float xRef = ref MemoryMarshal.GetReference(x);
+            ref float yRef = ref MemoryMarshal.GetReference(y);
+            ref float dRef = ref MemoryMarshal.GetReference(destination);
+
+            nuint remainder = (uint)(x.Length);
+
+#if NET8_0_OR_GREATER
+            if (Vector512.IsHardwareAccelerated)
+            {
+                if (remainder >= (uint)(Vector512<float>.Count))
+                {
+                    Vectorized512(ref xRef, ref yRef, ref dRef, remainder);
+                }
+                else
+                {
+                    // We have less than a vector's worth of elements, so we can only handle
+                    // this as scalar. To do this efficiently, we use a small jump table with
+                    // fall-through: a single length check, one jump, and then linear execution.
+
+                    Vectorized512Small(ref xRef, ref yRef, ref dRef, remainder);
+                }
+
+                return;
+            }
+#endif
+
+            if (Vector256.IsHardwareAccelerated)
+            {
+                if (remainder >= (uint)(Vector256<float>.Count))
+                {
+                    Vectorized256(ref xRef, ref yRef, ref dRef, remainder);
+                }
+                else
+                {
+                    // We have less than a vector's worth of elements, so we can only handle
+                    // this as scalar. To do this efficiently, we use a small jump table with
+                    // fall-through: a single length check, one jump, and then linear execution.
+
+                    Vectorized256Small(ref xRef, ref yRef, ref dRef, remainder);
+                }
+
+                return;
+            }
+
+            if (Vector128.IsHardwareAccelerated)
+            {
+                if (remainder >= (uint)(Vector128<float>.Count))
+                {
+                    Vectorized128(ref xRef, ref yRef, ref dRef, remainder);
+                }
+                else
+                {
+                    // We have less than a vector's worth of elements, so we can only handle
+                    // this as scalar. To do this efficiently, we use a small jump table with
+                    // fall-through: a single length check, one jump, and then linear execution.
+ + Vectorized128Small(ref xRef, ref yRef, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, ref yRef, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, ref float yRef, ref float dRef, nuint length) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, i) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, i), + Unsafe.Add(ref yRef, i)); + } + } + + static void Vectorized128(ref float xRef, ref float yRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + Vector128 end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count))); + + if (remainder > (uint)(Vector128.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(dPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector128))) == 0); + + remainder -= misalignment; + } + + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. 
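+                        // Non-temporal stores bypass the cache hierarchy, so they only pay off
+                        // when the destination is large enough that caching it would evict more
+                        // useful data; the NonTemporalByteThreshold check above approximates
+                        // that crossover point.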
+ + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector128.Count * 8); + yPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector128.Count * 0)); + vector2.Store(dPtr + (uint)(Vector128.Count * 1)); + vector3.Store(dPtr + (uint)(Vector128.Count * 2)); + vector4.Store(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector128.Count * 4)); + vector2.Store(dPtr + (uint)(Vector128.Count * 5)); + vector3.Store(dPtr + (uint)(Vector128.Count * 6)); + vector4.Store(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + yPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
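+                // For example, with Vector128<float>.Count == 4 and 27 elements remaining:
+                //   endIndex  = 27
+                //   remainder = (27 + 3) & ~3 == 28, so we dispatch to case (28 / 4) == 7
+                //   cases 7..2 store freshly computed vectors at offsets 0, 4, 8, 12, 16, 20
+                //   case 1 stores the precomputed `end` vector at offset 27 - 4 == 23
+                //   case 0 stores the precomputed `beg` vector back at dRefBeg
+                // The stores at offsets 20 and 23 overlap by one element, but both write the
+                // same value, so the double store is harmless.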
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); + + switch (remainder / (uint)(Vector128.Count)) + { + case 8: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); + goto case 7; + } + + case 7: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); + goto case 6; + } + + case 6: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); + goto case 5; + } + + case 5: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); + goto case 4; + } + + case 4: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); + goto case 3; + } + + case 3: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); + goto case 2; + } + + case 2: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized128Small(ref float xRef, ref float yRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 3: + { + Unsafe.Add(ref dRef, 2) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TBinaryOperator.Invoke(xRef, yRef); + goto case 0; + } + + case 0: + { + break; + } + } + } + + static void Vectorized256(ref float xRef, ref float yRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector256 beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef)); + Vector256 end = 
TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count))); + + if (remainder > (uint)(Vector256.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(dPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector256))) == 0); + + remainder -= misalignment; + } + + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 6)); + 
vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + yPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 0)); + vector2.Store(dPtr + (uint)(Vector256.Count * 1)); + vector3.Store(dPtr + (uint)(Vector256.Count * 2)); + vector4.Store(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 4)); + vector2.Store(dPtr + (uint)(Vector256.Count * 5)); + vector3.Store(dPtr + (uint)(Vector256.Count * 6)); + vector4.Store(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + yPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
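+                // The overlapping stores in the cases below are safe even when the operation
+                // runs exactly in place (destination identical to a source): `beg` and `end`
+                // were computed up front, before any results were written, so every overlapped
+                // element is simply written twice with the same value.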
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); + + switch (remainder / (uint)(Vector256.Count)) + { + case 8: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); + goto case 7; + } + + case 7: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); + goto case 6; + } + + case 6: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); + goto case 5; + } + + case 5: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); + goto case 4; + } + + case 4: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); + goto case 3; + } + + case 3: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); + goto case 2; + } + + case 2: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized256Small(ref float xRef, ref float yRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + Vector128 end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = 
TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TBinaryOperator.Invoke(xRef, yRef); + goto case 0; + } + + case 0: + { + break; + } + } + } + +#if NET8_0_OR_GREATER + static void Vectorized512(ref float xRef, ref float yRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector512 beg = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef), + Vector512.LoadUnsafe(ref yRef)); + Vector512 end = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count))); + + if (remainder > (uint)(Vector512.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(dPtr) % (uint)(sizeof(Vector512)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector512))) == 0); + + remainder -= misalignment; + } + + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. 
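+                        // As a concrete check of the alignment step above, Vector512<float> is
+                        // 64 bytes: if dPtr % 64 == 16, misalignment == (64 - 16) / 4 == 12
+                        // elements, and advancing 12 floats (48 bytes) makes dPtr 64-byte
+                        // aligned. If dPtr was already aligned, misalignment is a full vector
+                        // (16 elements); those elements remain covered by the preloaded `beg`,
+                        // which case 0 of the trailing jump table stores.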
+ + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector512.Count * 8); + yPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector512.Count * 0)); + vector2.Store(dPtr + (uint)(Vector512.Count * 1)); + vector3.Store(dPtr + (uint)(Vector512.Count * 2)); + vector4.Store(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector512.Count * 4)); + vector2.Store(dPtr + (uint)(Vector512.Count * 5)); + vector3.Store(dPtr + (uint)(Vector512.Count * 6)); + vector4.Store(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + yPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
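+                // `(remainder + (Count - 1)) & -Count` rounds remainder up to the next
+                // multiple of Count: -Count in two's complement has the low log2(Count) bits
+                // clear, so after the Count - 1 bump the AND truncates back down to a multiple
+                // of Count. Dividing by Count then gives the number of full-vector stores
+                // needed to cover the tail.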
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector512.Count - 1)) & (nuint)(-Vector512.Count); + + switch (remainder / (uint)(Vector512.Count)) + { + case 8: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 8)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 8)); + goto case 7; + } + + case 7: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 7)); + goto case 6; + } + + case 6: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 6)); + goto case 5; + } + + case 5: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 5)); + goto case 4; + } + + case 4: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); + goto case 3; + } + + case 3: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); + goto case 2; + } + + case 2: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized512Small(ref float xRef, ref float yRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 15: + case 14: + case 13: + case 12: + case 11: + case 10: + case 9: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef)); + Vector256 end = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count)); + + break; + } + + case 8: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef)); + beg.StoreUnsafe(ref dRef); + + break; + } 
+
+                    case 7:
+                    case 6:
+                    case 5:
+                    {
+                        Debug.Assert(Vector128.IsHardwareAccelerated);
+
+                        Vector128<float> beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef),
+                                                                      Vector128.LoadUnsafe(ref yRef));
+                        Vector128<float> end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128<float>.Count)),
+                                                                      Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128<float>.Count)));
+
+                        beg.StoreUnsafe(ref dRef);
+                        end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128<float>.Count));
+
+                        break;
+                    }
+
+                    case 4:
+                    {
+                        Debug.Assert(Vector128.IsHardwareAccelerated);
+
+                        Vector128<float> beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef),
+                                                                      Vector128.LoadUnsafe(ref yRef));
+                        beg.StoreUnsafe(ref dRef);
+
+                        break;
+                    }
+
+                    case 3:
+                    {
+                        Unsafe.Add(ref dRef, 2) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 2),
+                                                                         Unsafe.Add(ref yRef, 2));
+                        goto case 2;
+                    }
+
+                    case 2:
+                    {
+                        Unsafe.Add(ref dRef, 1) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 1),
+                                                                         Unsafe.Add(ref yRef, 1));
+                        goto case 1;
+                    }
+
+                    case 1:
+                    {
+                        dRef = TBinaryOperator.Invoke(xRef, yRef);
+                        goto case 0;
+                    }
+
+                    case 0:
+                    {
+                        break;
+                    }
+                }
+            }
+#endif
+        }
+
+        /// <summary>
+        /// Performs an element-wise operation on <paramref name="x"/> and <paramref name="y"/>,
+        /// and writes the results to <paramref name="destination"/>.
+        /// </summary>
+        /// <typeparam name="TBinaryOperator">
+        /// Specifies the operation to perform on each element loaded from <paramref name="x"/> with <paramref name="y"/>.
+        /// </typeparam>
+        private static void InvokeSpanScalarIntoSpan<TBinaryOperator>(
+            ReadOnlySpan<float> x, float y, Span<float> destination)
+            where TBinaryOperator : struct, IBinaryOperator =>
+            InvokeSpanScalarIntoSpan<IdentityOperator, TBinaryOperator>(x, y, destination);
+
+        /// <summary>
+        /// Performs an element-wise operation on <paramref name="x"/> and <paramref name="y"/>,
+        /// and writes the results to <paramref name="destination"/>.
+        /// </summary>
+        /// <typeparam name="TTransformOperator">
+        /// Specifies the operation to perform on each element loaded from <paramref name="x"/>.
+        /// It is not used with <paramref name="y"/>.
+        /// </typeparam>
+        /// <typeparam name="TBinaryOperator">
+        /// Specifies the operation to perform on the transformed value from <paramref name="x"/> with <paramref name="y"/>.
+        /// </typeparam>
+        private static void InvokeSpanScalarIntoSpan<TTransformOperator, TBinaryOperator>(
+            ReadOnlySpan<float> x, float y, Span<float> destination)
+            where TTransformOperator : struct, IUnaryOperator
+            where TBinaryOperator : struct, IBinaryOperator
+        {
+            if (x.Length > destination.Length)
+            {
+                ThrowHelper.ThrowArgument_DestinationTooShort();
+            }
+
+            ValidateInputOutputSpanNonOverlapping(x, destination);
+
+            // Since every branch has a cost, and that cost is essentially wasted on larger
+            // inputs, we order the checks so that small sizes take as few branches as possible.
+
+            ref float xRef = ref MemoryMarshal.GetReference(x);
+            ref float dRef = ref MemoryMarshal.GetReference(destination);
+
+            nuint remainder = (uint)(x.Length);
+
+#if NET8_0_OR_GREATER
+            if (Vector512.IsHardwareAccelerated)
+            {
+                if (remainder >= (uint)(Vector512<float>.Count))
+                {
+                    Vectorized512(ref xRef, y, ref dRef, remainder);
+                }
+                else
+                {
+                    // We have less than a vector's worth of elements, so we can only handle
+                    // this as scalar. To do this efficiently, we use a small jump table with
+                    // fall-through: a single length check, one jump, and then linear execution.
+
+                    Vectorized512Small(ref xRef, y, ref dRef, remainder);
+                }
+
+                return;
+            }
+#endif
+
+            if (Vector256.IsHardwareAccelerated)
+            {
+                if (remainder >= (uint)(Vector256<float>.Count))
+                {
+                    Vectorized256(ref xRef, y, ref dRef, remainder);
+                }
+                else
+                {
+                    // We have less than a vector's worth of elements, so we can only handle
+                    // this as scalar. To do this efficiently, we use a small jump table with
+                    // fall-through: a single length check, one jump, and then linear execution.
+ + Vectorized256Small(ref xRef, y, ref dRef, remainder); + } + + return; + } + + if (Vector128.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector128.Count)) + { + Vectorized128(ref xRef, y, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized128Small(ref xRef, y, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, y, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, float y, ref float dRef, nuint length) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, i) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, i)), + y); + } + } + + static void Vectorized128(ref float xRef, float y, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector128 yVec = Vector128.Create(y); + + Vector128 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)), + yVec); + Vector128 end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))), + yVec); + + if (remainder > (uint)(Vector128.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(dPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); + + xPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector128))) == 0); + + remainder -= misalignment; + } + + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. 
+ + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3))), + yVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7))), + yVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3))), + yVec); + + vector1.Store(dPtr + (uint)(Vector128.Count * 0)); + vector2.Store(dPtr + (uint)(Vector128.Count * 1)); + vector3.Store(dPtr + (uint)(Vector128.Count * 2)); + vector4.Store(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7))), + yVec); + + vector1.Store(dPtr + (uint)(Vector128.Count * 4)); + vector2.Store(dPtr + (uint)(Vector128.Count * 5)); + vector3.Store(dPtr + (uint)(Vector128.Count * 6)); + vector4.Store(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
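+                // Note that yVec was broadcast once at the top of this function and is reused
+                // by every case below: each case only reloads and transforms its block of x,
+                // computing TBinaryOperator.Invoke(TTransformOperator.Invoke(x[i]), y) per lane.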
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); + + switch (remainder / (uint)(Vector128.Count)) + { + case 8: + { + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); + goto case 7; + } + + case 7: + { + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); + goto case 6; + } + + case 6: + { + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); + goto case 5; + } + + case 5: + { + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); + goto case 4; + } + + case 4: + { + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); + goto case 3; + } + + case 3: + { + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); + goto case 2; + } + + case 2: + { + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized128Small(ref float xRef, float y, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 3: + { + Unsafe.Add(ref dRef, 2) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 2)), + y); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 1)), + y); + goto case 1; + } + + case 1: + { + dRef = TBinaryOperator.Invoke(TTransformOperator.Invoke(xRef), y); + goto case 0; + } + + case 0: + { + break; + } + } + } + + static void Vectorized256(ref float xRef, float y, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector256 yVec = Vector256.Create(y); + + Vector256 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)), + yVec); + Vector256 end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count))), + yVec); + + if (remainder > (uint)(Vector256.Count * 8)) + { + // Pinning is cheap and will be 
short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(dPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + + xPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector256))) == 0); + + remainder -= misalignment; + } + + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3))), + yVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7))), + yVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3))), + yVec); + + vector1.Store(dPtr + (uint)(Vector256.Count * 0)); + vector2.Store(dPtr + (uint)(Vector256.Count * 1)); + vector3.Store(dPtr + (uint)(Vector256.Count * 2)); + vector4.Store(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7))), + yVec); + + vector1.Store(dPtr + (uint)(Vector256.Count * 4)); + vector2.Store(dPtr + (uint)(Vector256.Count * 5)); + vector3.Store(dPtr + (uint)(Vector256.Count * 6)); + vector4.Store(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); + + switch (remainder / (uint)(Vector256.Count)) + { + case 8: + { + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); + goto case 7; + } + + case 7: + { + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); + goto case 6; + } + + case 6: + { + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); + goto case 5; + } + + case 5: + { + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); + goto case 4; + } + + case 4: + { + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); + goto case 3; + } + + case 3: + { + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); + goto case 2; + } + + case 2: + { + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized256Small(ref float xRef, float y, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 yVec = Vector128.Create(y); + + Vector128 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)), + yVec); + Vector128 end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))), + yVec); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)), + Vector128.Create(y)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 2)), + y); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 1)), + y); + goto case 1; + } + + case 1: + { + dRef 
= TBinaryOperator.Invoke(TTransformOperator.Invoke(xRef), y); + goto case 0; + } + + case 0: + { + break; + } + } + } + +#if NET8_0_OR_GREATER + static void Vectorized512(ref float xRef, float y, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector512 yVec = Vector512.Create(y); + + Vector512 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef)), + yVec); + Vector512 end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count))), + yVec); + + if (remainder > (uint)(Vector512.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(dPtr) % (uint)(sizeof(Vector512)))) / sizeof(float); + + xPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector512))) == 0); + + remainder -= misalignment; + } + + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. 
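+
+                        // Note (editorial, not in the original change): non-temporal (streaming)
+                        // stores write around the cache hierarchy, so filling a destination this
+                        // large doesn't evict data the caller may still need. They also require an
+                        // aligned address, which is why this branch is additionally gated on canAlign.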
+ + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3))), + yVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7))), + yVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
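+
+                            // With 512-bit vectors (16 floats per Vector512<float> when AVX-512 is
+                            // available), each iteration of this unrolled loop consumes
+                            // 8 * 16 == 128 elements.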
+ + xPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3))), + yVec); + + vector1.Store(dPtr + (uint)(Vector512.Count * 0)); + vector2.Store(dPtr + (uint)(Vector512.Count * 1)); + vector3.Store(dPtr + (uint)(Vector512.Count * 2)); + vector4.Store(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7))), + yVec); + + vector1.Store(dPtr + (uint)(Vector512.Count * 4)); + vector2.Store(dPtr + (uint)(Vector512.Count * 5)); + vector3.Store(dPtr + (uint)(Vector512.Count * 6)); + vector4.Store(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
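+
+            // Worked example, assuming Vector512<float>.Count == 16: for remainder == 100,
+            // endIndex is 100 and remainder rounds up to (100 + 15) & ~15 == 112, so the
+            // switch below enters at case 7 (112 / 16) and falls through. Case 1 then stores
+            // `end` at endIndex - 16 == 84, rewriting a few elements that case 2 already
+            // produced; the overlap is safe because the operator is a pure element-wise
+            // function, so the recomputed values are identical.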
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector512.Count - 1)) & (nuint)(-Vector512.Count); + + switch (remainder / (uint)(Vector512.Count)) + { + case 8: + { + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 8))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 8)); + goto case 7; + } + + case 7: + { + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 7)); + goto case 6; + } + + case 6: + { + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 6)); + goto case 5; + } + + case 5: + { + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 5)); + goto case 4; + } + + case 4: + { + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); + goto case 3; + } + + case 3: + { + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); + goto case 2; + } + + case 2: + { + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized512Small(ref float xRef, float y, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 15: + case 14: + case 13: + case 12: + case 11: + case 10: + case 9: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 yVec = Vector256.Create(y); + + Vector256 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)), + yVec); + Vector256 end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count))), + yVec); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count)); + + break; + } + + case 8: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)), + Vector256.Create(y)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 yVec = Vector128.Create(y); + + Vector128 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)), + yVec); + Vector128 end = 
TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128<float>.Count))),
+                                                                   yVec);
+
+                    beg.StoreUnsafe(ref dRef);
+                    end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128<float>.Count));
+
+                    break;
+                }
+
+                case 4:
+                {
+                    Debug.Assert(Vector128.IsHardwareAccelerated);
+
+                    Vector128<float> beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)),
+                                                                  Vector128.Create(y));
+                    beg.StoreUnsafe(ref dRef);
+
+                    break;
+                }
+
+                case 3:
+                {
+                    Unsafe.Add(ref dRef, 2) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 2)),
+                                                                     y);
+                    goto case 2;
+                }
+
+                case 2:
+                {
+                    Unsafe.Add(ref dRef, 1) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 1)),
+                                                                     y);
+                    goto case 1;
+                }
+
+                case 1:
+                {
+                    dRef = TBinaryOperator.Invoke(TTransformOperator.Invoke(xRef), y);
+                    goto case 0;
+                }
+
+                case 0:
+                {
+                    break;
+                }
+            }
+        }
+#endif
+    }
+
+    /// <summary>
+    /// Performs an element-wise operation on <paramref name="x"/>, <paramref name="y"/>, and <paramref name="z"/>,
+    /// and writes the results to <paramref name="destination"/>.
+    /// </summary>
+    /// <typeparam name="TTernaryOperator">
+    /// Specifies the operation to perform on the pair-wise elements loaded from <paramref name="x"/>, <paramref name="y"/>,
+    /// and <paramref name="z"/>.
+    /// </typeparam>
+    private static void InvokeSpanSpanSpanIntoSpan<TTernaryOperator>(
+        ReadOnlySpan<float> x, ReadOnlySpan<float> y, ReadOnlySpan<float> z, Span<float> destination)
+        where TTernaryOperator : struct, ITernaryOperator
+    {
+        if (x.Length != y.Length || x.Length != z.Length)
+        {
+            ThrowHelper.ThrowArgument_SpansMustHaveSameLength();
+        }
+
+        if (x.Length > destination.Length)
+        {
+            ThrowHelper.ThrowArgument_DestinationTooShort();
+        }
+
+        ValidateInputOutputSpanNonOverlapping(x, destination);
+        ValidateInputOutputSpanNonOverlapping(y, destination);
+        ValidateInputOutputSpanNonOverlapping(z, destination);
+
+        // Since every branch has a cost and since that cost is
+        // essentially lost for larger inputs, we do branches
+        // in a way that allows us to have the minimum possible
+        // for small sizes
+
+        ref float xRef = ref MemoryMarshal.GetReference(x);
+        ref float yRef = ref MemoryMarshal.GetReference(y);
+        ref float zRef = ref MemoryMarshal.GetReference(z);
+        ref float dRef = ref MemoryMarshal.GetReference(destination);
+
+        nuint remainder = (uint)(x.Length);
+
+#if NET8_0_OR_GREATER
+        if (Vector512.IsHardwareAccelerated)
+        {
+            if (remainder >= (uint)(Vector512<float>.Count))
+            {
+                Vectorized512(ref xRef, ref yRef, ref zRef, ref dRef, remainder);
+            }
+            else
+            {
+                // We have less than a vector and so we can only handle this as scalar. To do this
+                // efficiently, we simply have a small jump table and fallthrough. So we get a simple
+                // length check, single jump, and then linear execution.
+
+                Vectorized512Small(ref xRef, ref yRef, ref zRef, ref dRef, remainder);
+            }
+
+            return;
+        }
+#endif
+
+        if (Vector256.IsHardwareAccelerated)
+        {
+            if (remainder >= (uint)(Vector256<float>.Count))
+            {
+                Vectorized256(ref xRef, ref yRef, ref zRef, ref dRef, remainder);
+            }
+            else
+            {
+                // We have less than a vector and so we can only handle this as scalar. To do this
+                // efficiently, we simply have a small jump table and fallthrough. So we get a simple
+                // length check, single jump, and then linear execution.
+
+                Vectorized256Small(ref xRef, ref yRef, ref zRef, ref dRef, remainder);
+            }
+
+            return;
+        }
+
+        if (Vector128.IsHardwareAccelerated)
+        {
+            if (remainder >= (uint)(Vector128<float>.Count))
+            {
+                Vectorized128(ref xRef, ref yRef, ref zRef, ref dRef, remainder);
+            }
+            else
+            {
+                // We have less than a vector and so we can only handle this as scalar. To do this
+                // efficiently, we simply have a small jump table and fallthrough.
So we get a simple + // length check, single jump, and then linear execution. + + Vectorized128Small(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint length) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, i) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, i), + Unsafe.Add(ref yRef, i), + Unsafe.Add(ref zRef, i)); + } + } + + static void Vectorized128(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + Vector128.LoadUnsafe(ref zRef)); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count))); + + if (remainder > (uint)(Vector128.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pz = &zRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* zPtr = pz; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(dPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + zPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector128))) == 0); + + remainder -= misalignment; + } + + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. 
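+
+                        // The loop body below computes four vectors before storing any of them;
+                        // this grouping plausibly keeps the load and store streams independent so
+                        // the CPU can overlap the memory operations (an editorial inference, not
+                        // stated by the change itself).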
+ + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector128.Count * 8); + yPtr += (uint)(Vector128.Count * 8); + zPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector128.Count * 0)); + vector2.Store(dPtr + (uint)(Vector128.Count * 1)); + vector3.Store(dPtr + (uint)(Vector128.Count * 2)); + vector4.Store(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector128.Count * 4)); + vector2.Store(dPtr + (uint)(Vector128.Count * 5)); + vector3.Store(dPtr + (uint)(Vector128.Count * 6)); + vector4.Store(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + yPtr += (uint)(Vector128.Count * 8); + zPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + zRef = ref *zPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
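+
+            // Worked example, assuming Vector128<float>.Count == 4: for remainder == 10, the
+            // mask rounds remainder up to 12, dispatching to case 3; case 2 writes elements
+            // 4..7 and case 1 stores `end` at endIndex - 4 == 6, recomputing elements 6..7.
+            // The overlapping stores rely on TTernaryOperator being a pure element-wise
+            // function, so the rewritten values match.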
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); + + switch (remainder / (uint)(Vector128.Count)) + { + case 8: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 8)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); + goto case 7; + } + + case 7: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 7)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); + goto case 6; + } + + case 6: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 6)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); + goto case 5; + } + + case 5: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 5)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); + goto case 4; + } + + case 4: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 4)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); + goto case 3; + } + + case 3: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 3)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); + goto case 2; + } + + case 2: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 2)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized128Small(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 3: + { + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2), + Unsafe.Add(ref zRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 
1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1), + Unsafe.Add(ref zRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TTernaryOperator.Invoke(xRef, yRef, zRef); + goto case 0; + } + + case 0: + { + break; + } + } + } + + static void Vectorized256(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef), + Vector256.LoadUnsafe(ref zRef)); + Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count))); + + if (remainder > (uint)(Vector256.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pz = &zRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* zPtr = pz; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(dPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + zPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector256))) == 0); + + remainder -= misalignment; + } + + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. 
+ + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector256.Count * 8); + yPtr += (uint)(Vector256.Count * 8); + zPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 0)); + vector2.Store(dPtr + (uint)(Vector256.Count * 1)); + vector3.Store(dPtr + (uint)(Vector256.Count * 2)); + vector4.Store(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 4)); + vector2.Store(dPtr + (uint)(Vector256.Count * 5)); + vector3.Store(dPtr + (uint)(Vector256.Count * 6)); + vector4.Store(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + yPtr += (uint)(Vector256.Count * 8); + zPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + zRef = ref *zPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); + + switch (remainder / (uint)(Vector256.Count)) + { + case 8: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 8)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); + goto case 7; + } + + case 7: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 7)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); + goto case 6; + } + + case 6: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 6)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); + goto case 5; + } + + case 5: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 5)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); + goto case 4; + } + + case 4: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 4)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); + goto case 3; + } + + case 3: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 3)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); + goto case 2; + } + + case 2: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 2)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized256Small(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + 
Vector128.LoadUnsafe(ref zRef)); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + Vector128.LoadUnsafe(ref zRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2), + Unsafe.Add(ref zRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1), + Unsafe.Add(ref zRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TTernaryOperator.Invoke(xRef, yRef, zRef); + goto case 0; + } + + case 0: + { + break; + } + } + } + +#if NET8_0_OR_GREATER + static void Vectorized512(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector512 beg = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef), + Vector512.LoadUnsafe(ref yRef), + Vector512.LoadUnsafe(ref zRef)); + Vector512 end = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count))); + + if (remainder > (uint)(Vector512.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pz = &zRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* zPtr = pz; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(dPtr) % (uint)(sizeof(Vector512)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + zPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector512))) == 0); + + remainder -= misalignment; + } + + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. 
+ + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector512.Count * 8); + yPtr += (uint)(Vector512.Count * 8); + zPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector512.Count * 0)); + vector2.Store(dPtr + (uint)(Vector512.Count * 1)); + vector3.Store(dPtr + (uint)(Vector512.Count * 2)); + vector4.Store(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector512.Count * 4)); + vector2.Store(dPtr + (uint)(Vector512.Count * 5)); + vector3.Store(dPtr + (uint)(Vector512.Count * 6)); + vector4.Store(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + yPtr += (uint)(Vector512.Count * 8); + zPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + zRef = ref *zPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector512.Count - 1)) & (nuint)(-Vector512.Count); + + switch (remainder / (uint)(Vector512.Count)) + { + case 8: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 8)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 8)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 8)); + goto case 7; + } + + case 7: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 7)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 7)); + goto case 6; + } + + case 6: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 6)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 6)); + goto case 5; + } + + case 5: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 5)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 5)); + goto case 4; + } + + case 4: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 4)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); + goto case 3; + } + + case 3: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 3)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); + goto case 2; + } + + case 2: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 2)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized512Small(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 15: + case 14: + case 13: + case 12: + case 11: + case 10: + case 9: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), 
+                                                                  Vector256.LoadUnsafe(ref yRef),
+                                                                  Vector256.LoadUnsafe(ref zRef));
+                    Vector256<float> end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256<float>.Count)),
+                                                                   Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256<float>.Count)),
+                                                                   Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256<float>.Count)));
+
+                    beg.StoreUnsafe(ref dRef);
+                    end.StoreUnsafe(ref dRef, remainder - (uint)(Vector256<float>.Count));
+
+                    break;
+                }
+
+                case 8:
+                {
+                    Debug.Assert(Vector256.IsHardwareAccelerated);
+
+                    Vector256<float> beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef),
+                                                                   Vector256.LoadUnsafe(ref yRef),
+                                                                   Vector256.LoadUnsafe(ref zRef));
+                    beg.StoreUnsafe(ref dRef);
+
+                    break;
+                }
+
+                case 7:
+                case 6:
+                case 5:
+                {
+                    Debug.Assert(Vector128.IsHardwareAccelerated);
+
+                    Vector128<float> beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef),
+                                                                   Vector128.LoadUnsafe(ref yRef),
+                                                                   Vector128.LoadUnsafe(ref zRef));
+                    Vector128<float> end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128<float>.Count)),
+                                                                   Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128<float>.Count)),
+                                                                   Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128<float>.Count)));
+
+                    beg.StoreUnsafe(ref dRef);
+                    end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128<float>.Count));
+
+                    break;
+                }
+
+                case 4:
+                {
+                    Debug.Assert(Vector128.IsHardwareAccelerated);
+
+                    Vector128<float> beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef),
+                                                                   Vector128.LoadUnsafe(ref yRef),
+                                                                   Vector128.LoadUnsafe(ref zRef));
+                    beg.StoreUnsafe(ref dRef);
+
+                    break;
+                }
+
+                case 3:
+                {
+                    Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2),
+                                                                      Unsafe.Add(ref yRef, 2),
+                                                                      Unsafe.Add(ref zRef, 2));
+                    goto case 2;
+                }
+
+                case 2:
+                {
+                    Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1),
+                                                                      Unsafe.Add(ref yRef, 1),
+                                                                      Unsafe.Add(ref zRef, 1));
+                    goto case 1;
+                }
+
+                case 1:
+                {
+                    dRef = TTernaryOperator.Invoke(xRef, yRef, zRef);
+                    goto case 0;
+                }
+
+                case 0:
+                {
+                    break;
+                }
+            }
+        }
+#endif
+    }
+
+    /// <summary>
+    /// Performs an element-wise operation on <paramref name="x"/>, <paramref name="y"/>, and <paramref name="z"/>,
+    /// and writes the results to <paramref name="destination"/>.
+    /// </summary>
+    /// <typeparam name="TTernaryOperator">
+    /// Specifies the operation to perform on the pair-wise elements loaded from <paramref name="x"/> and <paramref name="y"/>
+    /// with <paramref name="z"/>.
+    /// </typeparam>
+    private static void InvokeSpanSpanScalarIntoSpan<TTernaryOperator>(
+        ReadOnlySpan<float> x, ReadOnlySpan<float> y, float z, Span<float> destination)
+        where TTernaryOperator : struct, ITernaryOperator
+    {
+        if (x.Length != y.Length)
+        {
+            ThrowHelper.ThrowArgument_SpansMustHaveSameLength();
+        }
+
+        if (x.Length > destination.Length)
+        {
+            ThrowHelper.ThrowArgument_DestinationTooShort();
+        }
+
+        ValidateInputOutputSpanNonOverlapping(x, destination);
+        ValidateInputOutputSpanNonOverlapping(y, destination);
+
+        // Since every branch has a cost and since that cost is
+        // essentially lost for larger inputs, we do branches
+        // in a way that allows us to have the minimum possible
+        // for small sizes
+
+        ref float xRef = ref MemoryMarshal.GetReference(x);
+        ref float yRef = ref MemoryMarshal.GetReference(y);
+        ref float dRef = ref MemoryMarshal.GetReference(destination);
+
+        nuint remainder = (uint)(x.Length);
+
+#if NET8_0_OR_GREATER
+        if (Vector512.IsHardwareAccelerated)
+        {
+            if (remainder >= (uint)(Vector512<float>.Count))
+            {
+                Vectorized512(ref xRef, ref yRef, z, ref dRef, remainder);
+            }
+            else
+            {
+                // We have less than a vector and so we can only handle this as scalar. To do this
+                // efficiently, we simply have a small jump table and fallthrough. So we get a simple
+                // length check, single jump, and then linear execution.
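+
+                // e.g. remainder == 10 lands in the case 15..9 group of Vectorized512Small,
+                // which covers it with two overlapping Vector256 operations at offsets 0 and
+                // remainder - 8 (assuming Vector256<float>.Count == 8).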
+ + Vectorized512Small(ref xRef, ref yRef, z, ref dRef, remainder); + } + + return; + } +#endif + + if (Vector256.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector256.Count)) + { + Vectorized256(ref xRef, ref yRef, z, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized256Small(ref xRef, ref yRef, z, ref dRef, remainder); + } + + return; + } + + if (Vector128.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector128.Count)) + { + Vectorized128(ref xRef, ref yRef, z, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized128Small(ref xRef, ref yRef, z, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, ref yRef, z, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, ref float yRef, float z, ref float dRef, nuint length) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, i) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, i), + Unsafe.Add(ref yRef, i), + z); + } + } + + static void Vectorized128(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector128 zVec = Vector128.Create(z); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + zVec); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count)), + zVec); + + if (remainder > (uint)(Vector128.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. 
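+
+                        // Worked example, assuming a 16-byte Vector128 and 4-byte floats: if dPtr
+                        // is 8 bytes past a 16-byte boundary, misalignment is (16 - 8) / 4 == 2
+                        // elements, and advancing xPtr, yPtr, and dPtr by 2 floats aligns dPtr.
+                        // If dPtr is already aligned, the formula yields a whole vector
+                        // (16 / 4 == 4); skipping it is safe because `beg` above already covers
+                        // the first vector.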
+ + nuint misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(dPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector128))) == 0); + + remainder -= misalignment; + } + + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 0)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 1)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 2)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 3)), + zVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 4)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 5)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 6)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 7)), + zVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector128.Count * 8); + yPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 0)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 1)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 2)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 3)), + zVec); + + vector1.Store(dPtr + (uint)(Vector128.Count * 0)); + vector2.Store(dPtr + (uint)(Vector128.Count * 1)); + vector3.Store(dPtr + (uint)(Vector128.Count * 2)); + vector4.Store(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 4)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 5)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 6)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 7)), + zVec); + + vector1.Store(dPtr + (uint)(Vector128.Count * 4)); + vector2.Store(dPtr + (uint)(Vector128.Count * 5)); + vector3.Store(dPtr + (uint)(Vector128.Count * 6)); + vector4.Store(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + yPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
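+
+ // Worked example (editorial), assuming Vector128<float> (Count == 4): with 27 elements left,
+ // endIndex == 27 and remainder rounds up to 28, so the switch enters at case 7. Cases 7..2 store
+ // freshly computed vectors at element offsets 0, 4, ..., 20, case 1 stores the preloaded "end"
+ // vector at offset endIndex - 4 == 23 (re-writing element 23 with the same value), and case 0
+ // re-stores the preloaded "beg" vector at the start of the destination.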
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); + + switch (remainder / (uint)(Vector128.Count)) + { + case 8: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 8)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); + goto case 7; + } + + case 7: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 7)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); + goto case 6; + } + + case 6: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 6)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); + goto case 5; + } + + case 5: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 5)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); + goto case 4; + } + + case 4: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 4)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); + goto case 3; + } + + case 3: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 3)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); + goto case 2; + } + + case 2: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 2)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized128Small(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 3: + { + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2), + z); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1), + z); + goto case 1; + } + + case 1: + { + dRef = TTernaryOperator.Invoke(xRef, yRef, z); + goto case 0; + } + + case 0: + { + break; + } + } + } + + static void Vectorized256(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector256 zVec = 
Vector256.Create(z); + + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef), + zVec); + Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count)), + zVec); + + if (remainder > (uint)(Vector256.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(dPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector256))) == 0); + + remainder -= misalignment; + } + + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. 
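+
+ // Editorial note: this branch is only taken when the destination could be vector-aligned above
+ // (StoreAlignedNonTemporal requires an aligned address) and when the remaining data exceeds
+ // NonTemporalByteThreshold, since non-temporal stores bypass the cache and only pay off when the
+ // output is too large to benefit from caching anyway.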
+ + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 0)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 1)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 2)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 3)), + zVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 4)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 5)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 6)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 7)), + zVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector256.Count * 8); + yPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 0)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 1)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 2)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 3)), + zVec); + + vector1.Store(dPtr + (uint)(Vector256.Count * 0)); + vector2.Store(dPtr + (uint)(Vector256.Count * 1)); + vector3.Store(dPtr + (uint)(Vector256.Count * 2)); + vector4.Store(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 4)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 5)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 6)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 7)), + zVec); + + vector1.Store(dPtr + (uint)(Vector256.Count * 4)); + vector2.Store(dPtr + (uint)(Vector256.Count * 5)); + vector3.Store(dPtr + (uint)(Vector256.Count * 6)); + vector4.Store(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + yPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
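+
+ // Editorial note: re-storing the preloaded "beg" and "end" vectors in cases 0 and 1 below can
+ // overwrite elements an earlier case already produced, but with identical values: both were
+ // computed from the inputs before any store to the destination occurred. That ordering is also
+ // what keeps the results correct if the destination exactly aliases an input, the one overlap
+ // shape the non-overlapping validation is intended to permit.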
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); + + switch (remainder / (uint)(Vector256.Count)) + { + case 8: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 8)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); + goto case 7; + } + + case 7: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 7)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); + goto case 6; + } + + case 6: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 6)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); + goto case 5; + } + + case 5: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 5)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); + goto case 4; + } + + case 4: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 4)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); + goto case 3; + } + + case 3: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 3)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); + goto case 2; + } + + case 2: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 2)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized256Small(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 zVec = Vector128.Create(z); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + zVec); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count)), + zVec); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref 
xRef), + Vector128.LoadUnsafe(ref yRef), + Vector128.Create(z)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2), + z); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1), + z); + goto case 1; + } + + case 1: + { + dRef = TTernaryOperator.Invoke(xRef, yRef, z); + goto case 0; + } + + case 0: + { + break; + } + } + } + +#if NET8_0_OR_GREATER + static void Vectorized512(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector512 zVec = Vector512.Create(z); + + Vector512 beg = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef), + Vector512.LoadUnsafe(ref yRef), + zVec); + Vector512 end = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count)), + zVec); + + if (remainder > (uint)(Vector512.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(dPtr) % (uint)(sizeof(Vector512)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector512))) == 0); + + remainder -= misalignment; + } + + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. 
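+
+ // Editorial note: NonTemporalByteThreshold is a byte count, so dividing it by sizeof(float)
+ // converts it to the element count that "remainder" is measured in; on this path the alignment
+ // step above targets sizeof(Vector512<float>) == 64 bytes.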
+ + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 0)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 1)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 2)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 3)), + zVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 4)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 5)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 6)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 7)), + zVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector512.Count * 8); + yPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 0)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 1)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 2)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 3)), + zVec); + + vector1.Store(dPtr + (uint)(Vector512.Count * 0)); + vector2.Store(dPtr + (uint)(Vector512.Count * 1)); + vector3.Store(dPtr + (uint)(Vector512.Count * 2)); + vector4.Store(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 4)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 5)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 6)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 7)), + zVec); + + vector1.Store(dPtr + (uint)(Vector512.Count * 4)); + vector2.Store(dPtr + (uint)(Vector512.Count * 5)); + vector3.Store(dPtr + (uint)(Vector512.Count * 6)); + vector4.Store(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + yPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector512.Count - 1)) & (nuint)(-Vector512.Count); + + switch (remainder / (uint)(Vector512.Count)) + { + case 8: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 8)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 8)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 8)); + goto case 7; + } + + case 7: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 7)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 7)); + goto case 6; + } + + case 6: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 6)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 6)); + goto case 5; + } + + case 5: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 5)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 5)); + goto case 4; + } + + case 4: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 4)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); + goto case 3; + } + + case 3: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 3)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); + goto case 2; + } + + case 2: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 2)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized512Small(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 15: + case 14: + case 13: + case 12: + case 11: + case 10: + case 9: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 zVec = Vector256.Create(z); + + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef), + zVec); + Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count)), + zVec); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count)); + + break; + } + + case 8: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = 
TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef), + Vector256.Create(z)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 zVec = Vector128.Create(z); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + zVec); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count)), + zVec); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + Vector128.Create(z)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2), + z); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1), + z); + goto case 1; + } + + case 1: + { + dRef = TTernaryOperator.Invoke(xRef, yRef, z); + goto case 0; + } + + case 0: + { + break; + } + } + } +#endif + } + + /// + /// Performs an element-wise operation on , , and , + /// and writes the results to . + /// + /// + /// Specifies the operation to perform on the pair-wise element loaded from , with , + /// and the element loaded from . + /// + private static void InvokeSpanScalarSpanIntoSpan( + ReadOnlySpan x, float y, ReadOnlySpan z, Span destination) + where TTernaryOperator : struct, ITernaryOperator + { + if (x.Length != z.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ValidateInputOutputSpanNonOverlapping(x, destination); + ValidateInputOutputSpanNonOverlapping(z, destination); + + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float zRef = ref MemoryMarshal.GetReference(z); + ref float dRef = ref MemoryMarshal.GetReference(destination); + + nuint remainder = (uint)(x.Length); + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector512.Count)) + { + Vectorized512(ref xRef, y, ref zRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized512Small(ref xRef, y, ref zRef, ref dRef, remainder); + } + + return; + } +#endif + + if (Vector256.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector256.Count)) + { + Vectorized256(ref xRef, y, ref zRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. 
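+
+ // Illustration (editorial) of the small path: Vectorized256Small switches on the remaining
+ // element count (0..7). A remainder of 6, for example, falls into the shared 7/6/5 case, which
+ // stores one Vector128<float> of results at offset 0 and a second at offset remainder - 4 == 2,
+ // overlapping on elements 2 and 3; remainders of 3 or fewer fall through the scalar cases.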
+ + Vectorized256Small(ref xRef, y, ref zRef, ref dRef, remainder); + } + + return; + } + + if (Vector128.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector128.Count)) + { + Vectorized128(ref xRef, y, ref zRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized128Small(ref xRef, y, ref zRef, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, y, ref zRef, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, float y, ref float zRef, ref float dRef, nuint length) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, i) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, i), + y, + Unsafe.Add(ref zRef, i)); + } + } + + static void Vectorized128(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector128 yVec = Vector128.Create(y); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + yVec, + Vector128.LoadUnsafe(ref zRef)); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count))); + + if (remainder > (uint)(Vector128.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pz = &zRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* zPtr = pz; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(dPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); + + xPtr += misalignment; + zPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector128))) == 0); + + remainder -= misalignment; + } + + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. 
+ + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector128.Count * 8); + zPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector128.Count * 0)); + vector2.Store(dPtr + (uint)(Vector128.Count * 1)); + vector3.Store(dPtr + (uint)(Vector128.Count * 2)); + vector4.Store(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector128.Count * 4)); + vector2.Store(dPtr + (uint)(Vector128.Count * 5)); + vector3.Store(dPtr + (uint)(Vector128.Count * 6)); + vector4.Store(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + zPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + zRef = ref *zPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); + + switch (remainder / (uint)(Vector128.Count)) + { + case 8: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); + goto case 7; + } + + case 7: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); + goto case 6; + } + + case 6: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); + goto case 5; + } + + case 5: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); + goto case 4; + } + + case 4: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); + goto case 3; + } + + case 3: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); + goto case 2; + } + + case 2: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized128Small(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 3: + { + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + y, + Unsafe.Add(ref zRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + y, + Unsafe.Add(ref zRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TTernaryOperator.Invoke(xRef, y, zRef); + goto case 0; + } + + case 0: + { + break; + } + } + } + + static void Vectorized256(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector256 yVec = 
Vector256.Create(y); + + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + yVec, + Vector256.LoadUnsafe(ref zRef)); + Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count))); + + if (remainder > (uint)(Vector256.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pz = &zRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* zPtr = pz; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(dPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + + xPtr += misalignment; + zPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector256))) == 0); + + remainder -= misalignment; + } + + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. 
+ + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector256.Count * 8); + zPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 0)); + vector2.Store(dPtr + (uint)(Vector256.Count * 1)); + vector3.Store(dPtr + (uint)(Vector256.Count * 2)); + vector4.Store(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 4)); + vector2.Store(dPtr + (uint)(Vector256.Count * 5)); + vector3.Store(dPtr + (uint)(Vector256.Count * 6)); + vector4.Store(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + zPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + zRef = ref *zPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); + + switch (remainder / (uint)(Vector256.Count)) + { + case 8: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); + goto case 7; + } + + case 7: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); + goto case 6; + } + + case 6: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); + goto case 5; + } + + case 5: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); + goto case 4; + } + + case 4: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); + goto case 3; + } + + case 3: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); + goto case 2; + } + + case 2: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized256Small(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 yVec = Vector128.Create(y); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + yVec, + Vector128.LoadUnsafe(ref zRef)); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref 
xRef), + Vector128.Create(y), + Vector128.LoadUnsafe(ref zRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + y, + Unsafe.Add(ref zRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + y, + Unsafe.Add(ref zRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TTernaryOperator.Invoke(xRef, y, zRef); + goto case 0; + } + + case 0: + { + break; + } + } + } + +#if NET8_0_OR_GREATER + static void Vectorized512(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector512 yVec = Vector512.Create(y); + + Vector512 beg = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef), + yVec, + Vector512.LoadUnsafe(ref zRef)); + Vector512 end = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count)), + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count))); + + if (remainder > (uint)(Vector512.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pz = &zRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* zPtr = pz; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(dPtr) % (uint)(sizeof(Vector512)))) / sizeof(float); + + xPtr += misalignment; + zPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector512))) == 0); + + remainder -= misalignment; + } + + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. 
+ + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector512.Count * 8); + zPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector512.Count * 0)); + vector2.Store(dPtr + (uint)(Vector512.Count * 1)); + vector3.Store(dPtr + (uint)(Vector512.Count * 2)); + vector4.Store(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector512.Count * 4)); + vector2.Store(dPtr + (uint)(Vector512.Count * 5)); + vector3.Store(dPtr + (uint)(Vector512.Count * 6)); + vector4.Store(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + zPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + zRef = ref *zPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
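+ // For example, with Vector512.Count == 16 and remainder == 37 at this point: endIndex is 37 and remainder rounds up to 48, so the switch enters at case 3; cases 3 and 2 store full vectors at offsets 0 and 16, case 1 stores the preloaded end vector at endIndex - 16 == 21 to cover the final partial block, and case 0 stores beg at offset 0.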
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector512.Count - 1)) & (nuint)(-Vector512.Count); + + switch (remainder / (uint)(Vector512.Count)) + { + case 8: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 8)), + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 8)); + goto case 7; + } + + case 7: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7)), + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 7)); + goto case 6; + } + + case 6: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6)), + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 6)); + goto case 5; + } + + case 5: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5)), + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 5)); + goto case 4; + } + + case 4: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4)), + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); + goto case 3; + } + + case 3: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3)), + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); + goto case 2; + } + + case 2: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2)), + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized512Small(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 15: + case 14: + case 13: + case 12: + case 11: + case 10: + case 9: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 yVec = Vector256.Create(y); + + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + yVec, + Vector256.LoadUnsafe(ref zRef)); + Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count)); + + break; + } + + case 8: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = 
TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.Create(y), + Vector256.LoadUnsafe(ref zRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 yVec = Vector128.Create(y); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + yVec, + Vector128.LoadUnsafe(ref zRef)); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.Create(y), + Vector128.LoadUnsafe(ref zRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + y, + Unsafe.Add(ref zRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + y, + Unsafe.Add(ref zRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TTernaryOperator.Invoke(xRef, y, zRef); + goto case 0; + } + + case 0: + { + break; + } + } + } +#endif + } + + /// Performs (x * y) + z. It will be rounded as one ternary operation if such an operation is accelerated on the current hardware. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector128 FusedMultiplyAdd(Vector128 x, Vector128 y, Vector128 addend) + { + if (Fma.IsSupported) + { + return Fma.MultiplyAdd(x, y, addend); + } + + if (AdvSimd.IsSupported) + { + return AdvSimd.FusedMultiplyAdd(addend, x, y); + } + + return (x * y) + addend; + } + + /// Performs (x * y) + z. It will be rounded as one ternary operation if such an operation is accelerated on the current hardware. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector256 FusedMultiplyAdd(Vector256 x, Vector256 y, Vector256 addend) + { + if (Fma.IsSupported) + { + return Fma.MultiplyAdd(x, y, addend); + } + + return (x * y) + addend; + } + +#if NET8_0_OR_GREATER + /// Performs (x * y) + z. It will be rounded as one ternary operation if such an operation is accelerated on the current hardware. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector512 FusedMultiplyAdd(Vector512 x, Vector512 y, Vector512 addend) + { + if (Avx512F.IsSupported) + { + return Avx512F.FusedMultiplyAdd(x, y, addend); + } + + return (x * y) + addend; + } +#endif + + /// Aggregates all of the elements in the vector into a single value. + /// Specifies the operation to be performed on each pair of values. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static float HorizontalAggregate(Vector128 x) where TAggregate : struct, IBinaryOperator + { + // We need to do log2(count) operations to compute the total sum + + x = TAggregate.Invoke(x, Vector128.Shuffle(x, Vector128.Create(2, 3, 0, 1))); + x = TAggregate.Invoke(x, Vector128.Shuffle(x, Vector128.Create(1, 0, 3, 2))); + + return x.ToScalar(); + } + + /// Aggregates all of the elements in the vector into a single value. + /// Specifies the operation to be performed on each pair of values.
+ [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static float HorizontalAggregate(Vector256 x) where TAggregate : struct, IBinaryOperator => + HorizontalAggregate(TAggregate.Invoke(x.GetLower(), x.GetUpper())); + +#if NET8_0_OR_GREATER + /// Aggregates all of the elements in the vector into a single value. + /// Specifies the operation to be performed on each pair of values. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static float HorizontalAggregate(Vector512 x) where TAggregate : struct, IBinaryOperator => + HorizontalAggregate(TAggregate.Invoke(x.GetLower(), x.GetUpper())); +#endif + + /// Gets whether the specified float is negative. + private static bool IsNegative(float f) => float.IsNegative(f); + + /// Gets whether each specified float is negative. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector128 IsNegative(Vector128 vector) => + Vector128.LessThan(vector.AsInt32(), Vector128.Zero).AsSingle(); + + /// Gets whether each specified float is negative. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector256 IsNegative(Vector256 vector) => + Vector256.LessThan(vector.AsInt32(), Vector256.Zero).AsSingle(); + +#if NET8_0_OR_GREATER + /// Gets whether each specified float is negative. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector512 IsNegative(Vector512 vector) => + Vector512.LessThan(vector.AsInt32(), Vector512.Zero).AsSingle(); +#endif + + /// Gets whether the specified float is positive. + private static bool IsPositive(float f) => float.IsPositive(f); + + /// Gets whether each specified float is positive. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector128 IsPositive(Vector128 vector) => + Vector128.GreaterThan(vector.AsInt32(), Vector128.AllBitsSet).AsSingle(); + + /// Gets whether each specified float is positive. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector256 IsPositive(Vector256 vector) => + Vector256.GreaterThan(vector.AsInt32(), Vector256.AllBitsSet).AsSingle(); + +#if NET8_0_OR_GREATER + /// Gets whether each specified float is positive. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector512 IsPositive(Vector512 vector) => + Vector512.GreaterThan(vector.AsInt32(), Vector512.AllBitsSet).AsSingle(); +#endif + + /// Finds and returns the first NaN value in the vector. + /// The vector must have already been validated to contain a NaN. + private static float GetFirstNaN(Vector128 vector) + { + Debug.Assert(!Vector128.EqualsAll(vector, vector), "Expected vector to contain a NaN"); + return vector.GetElement(BitOperations.TrailingZeroCount((~Vector128.Equals(vector, vector)).ExtractMostSignificantBits())); + } + + /// Finds and returns the index of the first NaN value in the vector. + /// The vector must have already been validated to contain a NaN. + private static int GetFirstNaNIndex(Vector128 vector, Vector128 index) + { + Debug.Assert(!Vector128.EqualsAll(vector, vector), "Expected vector to contain a NaN"); + return index.GetElement(BitOperations.TrailingZeroCount((~Vector128.Equals(vector, vector)).ExtractMostSignificantBits())); + } + + /// Finds and returns the first NaN value in the vector. + /// The vector must have already been validated to contain a NaN.
+ private static float GetFirstNaN(Vector256 vector) + { + Debug.Assert(!Vector256.EqualsAll(vector, vector), "Expected vector to contain a NaN"); + return vector.GetElement(BitOperations.TrailingZeroCount((~Vector256.Equals(vector, vector)).ExtractMostSignificantBits())); + } + + /// Finds and returns the index of the first NaN value in the vector. + /// The vector must have already been validated to contain a NaN. + private static int GetFirstNaNIndex(Vector256 vector, Vector256 index) + { + Debug.Assert(!Vector256.EqualsAll(vector, vector), "Expected vector to contain a NaN"); + return index.GetElement(BitOperations.TrailingZeroCount((~Vector256.Equals(vector, vector)).ExtractMostSignificantBits())); + } + +#if NET8_0_OR_GREATER + /// Finds and returns the first NaN value in the vector. + /// The vector must have already been validated to contain a NaN. + private static float GetFirstNaN(Vector512 vector) + { + Debug.Assert(!Vector512.EqualsAll(vector, vector), "Expected vector to contain a NaN"); + return vector.GetElement(BitOperations.TrailingZeroCount((~Vector512.Equals(vector, vector)).ExtractMostSignificantBits())); + } + + /// Finds and returns the index of the first NaN value in the vector. + /// The vector must have already been validated to contain a NaN. + private static int GetFirstNaNIndex(Vector512 vector, Vector512 index) + { + Debug.Assert(!Vector512.EqualsAll(vector, vector), "Expected vector to contain a NaN"); + return index.GetElement(BitOperations.TrailingZeroCount((~Vector512.Equals(vector, vector)).ExtractMostSignificantBits())); + } +#endif + + /// Gets the base 2 logarithm of x. + private static float Log2(float x) => MathF.Log2(x); + + /// + /// Gets a vector mask that will be all-ones-set for the last count elements + /// and zero for all other elements. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector128 CreateAlignmentMaskSingleVector128(int count) => + Vector128.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(AlignmentUInt32Mask_16x16)), + (uint)(count * 16)); // first four floats in the row + + /// + /// Gets a vector mask that will be all-ones-set for the last count elements + /// and zero for all other elements. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector256 CreateAlignmentMaskSingleVector256(int count) => + Vector256.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(AlignmentUInt32Mask_16x16)), + (uint)(count * 16)); // first eight floats in the row + +#if NET8_0_OR_GREATER + /// + /// Gets a vector mask that will be all-ones-set for the last count elements + /// and zero for all other elements. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector512 CreateAlignmentMaskSingleVector512(int count) => + Vector512.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(AlignmentUInt32Mask_16x16)), + (uint)(count * 16)); // all sixteen floats in the row +#endif + + /// + /// Gets a vector mask that will be all-ones-set for the last count elements + /// and zero for all other elements. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector128 CreateRemainderMaskSingleVector128(int count) => + Vector128.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(RemainderUInt32Mask_16x16)), + (uint)((count * 16) + 12)); // last four floats in the row + + /// + /// Gets a vector mask that will be all-ones-set for the last count elements + /// and zero for all other elements.
+ /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector256 CreateRemainderMaskSingleVector256(int count) => + Vector256.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(RemainderUInt32Mask_16x16)), + (uint)((count * 16) + 8)); // last eight floats in the row + +#if NET8_0_OR_GREATER + /// + /// Gets a vector mask that will be all-ones-set for the last count elements + /// and zero for all other elements. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector512 CreateRemainderMaskSingleVector512(int count) => + Vector512.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(RemainderUInt32Mask_16x16)), + (uint)(count * 16)); // all sixteen floats in the row +#endif + + /// x + y + private readonly struct AddOperator : IAggregationOperator + { + public static float Invoke(float x, float y) => x + y; + public static Vector128 Invoke(Vector128 x, Vector128 y) => x + y; + public static Vector256 Invoke(Vector256 x, Vector256 y) => x + y; +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x, Vector512 y) => x + y; +#endif + + public static float Invoke(Vector128 x) => Vector128.Sum(x); + public static float Invoke(Vector256 x) => Vector256.Sum(x); +#if NET8_0_OR_GREATER + public static float Invoke(Vector512 x) => Vector512.Sum(x); +#endif + + public static float IdentityValue => 0; + } + + /// x - y + private readonly struct SubtractOperator : IBinaryOperator + { + public static float Invoke(float x, float y) => x - y; + public static Vector128 Invoke(Vector128 x, Vector128 y) => x - y; + public static Vector256 Invoke(Vector256 x, Vector256 y) => x - y; +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x, Vector512 y) => x - y; +#endif + } + + /// (x - y) * (x - y) + private readonly struct SubtractSquaredOperator : IBinaryOperator + { + public static float Invoke(float x, float y) + { + float tmp = x - y; + return tmp * tmp; + } + + public static Vector128 Invoke(Vector128 x, Vector128 y) + { + Vector128 tmp = x - y; + return tmp * tmp; + } + + public static Vector256 Invoke(Vector256 x, Vector256 y) + { + Vector256 tmp = x - y; + return tmp * tmp; + } + +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x, Vector512 y) + { + Vector512 tmp = x - y; + return tmp * tmp; + } +#endif + } + + /// x * y + private readonly struct MultiplyOperator : IAggregationOperator + { + public static float Invoke(float x, float y) => x * y; + public static Vector128 Invoke(Vector128 x, Vector128 y) => x * y; + public static Vector256 Invoke(Vector256 x, Vector256 y) => x * y; +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x, Vector512 y) => x * y; +#endif + + public static float Invoke(Vector128 x) => HorizontalAggregate(x); + public static float Invoke(Vector256 x) => HorizontalAggregate(x); +#if NET8_0_OR_GREATER + public static float Invoke(Vector512 x) => HorizontalAggregate(x); +#endif + + public static float IdentityValue => 1; + } + + /// x / y + private readonly struct DivideOperator : IBinaryOperator + { + public static float Invoke(float x, float y) => x / y; + public static Vector128 Invoke(Vector128 x, Vector128 y) => x / y; + public static Vector256 Invoke(Vector256 x, Vector256 y) => x / y; +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x, Vector512 y) => x / y; +#endif + } + + /// MathF.Max(x, y) (but NaNs may not be propagated) + private readonly struct MaxOperator : IAggregationOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] +
public static float Invoke(float x, float y) => + x == y ? + (IsNegative(x) ? y : x) : + (y > x ? y : x); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector128 Invoke(Vector128 x, Vector128 y) + { + if (AdvSimd.IsSupported) + { + return AdvSimd.Max(x, y); + } + + return + Vector128.ConditionalSelect(Vector128.Equals(x, y), + Vector128.ConditionalSelect(IsNegative(x), y, x), + Vector128.Max(x, y)); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector256 Invoke(Vector256 x, Vector256 y) => + Vector256.ConditionalSelect(Vector256.Equals(x, y), + Vector256.ConditionalSelect(IsNegative(x), y, x), + Vector256.Max(x, y)); + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector512 Invoke(Vector512 x, Vector512 y) => + Vector512.ConditionalSelect(Vector512.Equals(x, y), + Vector512.ConditionalSelect(IsNegative(x), y, x), + Vector512.Max(x, y)); +#endif + + public static float Invoke(Vector128 x) => HorizontalAggregate(x); + public static float Invoke(Vector256 x) => HorizontalAggregate(x); +#if NET8_0_OR_GREATER + public static float Invoke(Vector512 x) => HorizontalAggregate(x); +#endif + } + + private interface IIndexOfOperator + { + static abstract int Invoke(ref float result, float current, int resultIndex, int curIndex); + static abstract int Invoke(Vector128 result, Vector128 resultIndex); + static abstract void Invoke(ref Vector128 result, Vector128 current, ref Vector128 resultIndex, Vector128 curIndex); + static abstract int Invoke(Vector256 result, Vector256 resultIndex); + static abstract void Invoke(ref Vector256 result, Vector256 current, ref Vector256 resultIndex, Vector256 curIndex); +#if NET8_0_OR_GREATER + static abstract int Invoke(Vector512 result, Vector512 resultIndex); + static abstract void Invoke(ref Vector512 result, Vector512 current, ref Vector512 resultIndex, Vector512 curIndex); +#endif + } + + /// Returns the index of MathF.Max(x, y) + private readonly struct IndexOfMaxOperator : IIndexOfOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector128 result, Vector128 maxIndex) + { + Vector128 tmpResult = Vector128.Shuffle(result, Vector128.Create(2, 3, 0, 1)); + Vector128 tmpIndex = Vector128.Shuffle(maxIndex, Vector128.Create(2, 3, 0, 1)); + + Invoke(ref result, tmpResult, ref maxIndex, tmpIndex); + + tmpResult = Vector128.Shuffle(result, Vector128.Create(1, 0, 3, 2)); + tmpIndex = Vector128.Shuffle(maxIndex, Vector128.Create(1, 0, 3, 2)); + + Invoke(ref result, tmpResult, ref maxIndex, tmpIndex); + return maxIndex.ToScalar(); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector128 max, Vector128 current, ref Vector128 maxIndex, Vector128 curIndex) + { + Vector128 greaterThanMask = Vector128.GreaterThan(max, current); + + Vector128 equalMask = Vector128.Equals(max, current); + if (equalMask.AsInt32() != Vector128.Zero) + { + Vector128 negativeMask = IsNegative(current); + Vector128 lessThanMask = Vector128.LessThan(maxIndex, curIndex); + + greaterThanMask |= (negativeMask & equalMask) | (~IsNegative(max) & equalMask & lessThanMask.AsSingle()); + } + + max = ElementWiseSelect(greaterThanMask, max, current); + + maxIndex = ElementWiseSelect(greaterThanMask.AsInt32(), maxIndex, curIndex); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector256 result, Vector256 maxIndex) + { + // Max the upper/lower halves of the Vector256 + Vector128 resultLower = 
result.GetLower(); + Vector128 indexLower = maxIndex.GetLower(); + + Invoke(ref resultLower, result.GetUpper(), ref indexLower, maxIndex.GetUpper()); + return Invoke(resultLower, indexLower); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector256 max, Vector256 current, ref Vector256 maxIndex, Vector256 curIndex) + { + Vector256 greaterThanMask = Vector256.GreaterThan(max, current); + + Vector256 equalMask = Vector256.Equals(max, current); + if (equalMask.AsInt32() != Vector256.Zero) + { + Vector256 negativeMask = IsNegative(current); + Vector256 lessThanMask = Vector256.LessThan(maxIndex, curIndex); + + greaterThanMask |= (negativeMask & equalMask) | (~IsNegative(max) & equalMask & lessThanMask.AsSingle()); + } + + max = ElementWiseSelect(greaterThanMask, max, current); + + maxIndex = ElementWiseSelect(greaterThanMask.AsInt32(), maxIndex, curIndex); + } + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector512 result, Vector512 resultIndex) + { + // Max the upper/lower halves of the Vector512 + Vector256 resultLower = result.GetLower(); + Vector256 indexLower = resultIndex.GetLower(); + + Invoke(ref resultLower, result.GetUpper(), ref indexLower, resultIndex.GetUpper()); + return Invoke(resultLower, indexLower); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector512 max, Vector512 current, ref Vector512 maxIndex, Vector512 curIndex) + { + Vector512 greaterThanMask = Vector512.GreaterThan(max, current); + + Vector512 equalMask = Vector512.Equals(max, current); + if (equalMask.AsInt32() != Vector512.Zero) + { + Vector512 negativeMask = IsNegative(current); + Vector512 lessThanMask = Vector512.LessThan(maxIndex, curIndex); + + greaterThanMask |= (negativeMask & equalMask) | (~IsNegative(max) & equalMask & lessThanMask.AsSingle()); + } + + max = ElementWiseSelect(greaterThanMask, max, current); + + maxIndex = ElementWiseSelect(greaterThanMask.AsInt32(), maxIndex, curIndex); + } +#endif + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(ref float result, float current, int resultIndex, int curIndex) + { + if (result == current) + { + if (IsNegative(result) && !IsNegative(current)) + { + result = current; + return curIndex; + } + } + else if (current > result) + { + result = current; + return curIndex; + } + + return resultIndex; + } + } + + private readonly struct IndexOfMaxMagnitudeOperator : IIndexOfOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector128 result, Vector128 maxIndex) + { + Vector128 tmpResult = Vector128.Shuffle(result, Vector128.Create(2, 3, 0, 1)); + Vector128 tmpIndex = Vector128.Shuffle(maxIndex, Vector128.Create(2, 3, 0, 1)); + + Invoke(ref result, tmpResult, ref maxIndex, tmpIndex); + + tmpResult = Vector128.Shuffle(result, Vector128.Create(1, 0, 3, 2)); + tmpIndex = Vector128.Shuffle(maxIndex, Vector128.Create(1, 0, 3, 2)); + + Invoke(ref result, tmpResult, ref maxIndex, tmpIndex); + return maxIndex.ToScalar(); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector128 max, Vector128 current, ref Vector128 maxIndex, Vector128 curIndex) + { + Vector128 maxMag = Vector128.Abs(max), currentMag = Vector128.Abs(current); + + Vector128 greaterThanMask = Vector128.GreaterThan(maxMag, currentMag); + + Vector128 equalMask = Vector128.Equals(max, current); + if (equalMask.AsInt32() != Vector128.Zero) + { + Vector128
negativeMask = IsNegative(current); + Vector128 lessThanMask = Vector128.LessThan(maxIndex, curIndex); + + greaterThanMask |= (negativeMask & equalMask) | (~IsNegative(max) & equalMask & lessThanMask.AsSingle()); + } + + max = ElementWiseSelect(greaterThanMask, max, current); + + maxIndex = ElementWiseSelect(greaterThanMask.AsInt32(), maxIndex, curIndex); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector256 result, Vector256 maxIndex) + { + // Max the upper/lower halves of the Vector256 + Vector128 resultLower = result.GetLower(); + Vector128 indexLower = maxIndex.GetLower(); + + Invoke(ref resultLower, result.GetUpper(), ref indexLower, maxIndex.GetUpper()); + return Invoke(resultLower, indexLower); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector256 max, Vector256 current, ref Vector256 maxIndex, Vector256 curIndex) + { + Vector256 maxMag = Vector256.Abs(max), currentMag = Vector256.Abs(current); + + Vector256 greaterThanMask = Vector256.GreaterThan(maxMag, currentMag); + + Vector256 equalMask = Vector256.Equals(max, current); + if (equalMask.AsInt32() != Vector256.Zero) + { + Vector256 negativeMask = IsNegative(current); + Vector256 lessThanMask = Vector256.LessThan(maxIndex, curIndex); + + greaterThanMask |= (negativeMask & equalMask) | (~IsNegative(max) & equalMask & lessThanMask.AsSingle()); + } + + max = ElementWiseSelect(greaterThanMask, max, current); + + maxIndex = ElementWiseSelect(greaterThanMask.AsInt32(), maxIndex, curIndex); + } + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector512 result, Vector512 resultIndex) + { + // Max the upper/lower halves of the Vector512 + Vector256 resultLower = result.GetLower(); + Vector256 indexLower = resultIndex.GetLower(); + + Invoke(ref resultLower, result.GetUpper(), ref indexLower, resultIndex.GetUpper()); + return Invoke(resultLower, indexLower); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector512 max, Vector512 current, ref Vector512 maxIndex, Vector512 curIndex) + { + Vector512 maxMag = Vector512.Abs(max), currentMag = Vector512.Abs(current); + Vector512 greaterThanMask = Vector512.GreaterThan(maxMag, currentMag); + + Vector512 equalMask = Vector512.Equals(max, current); + if (equalMask.AsInt32() != Vector512.Zero) + { + Vector512 negativeMask = IsNegative(current); + Vector512 lessThanMask = Vector512.LessThan(maxIndex, curIndex); + + greaterThanMask |= (negativeMask & equalMask) | (~IsNegative(max) & equalMask & lessThanMask.AsSingle()); + } + + max = ElementWiseSelect(greaterThanMask, max, current); + + maxIndex = ElementWiseSelect(greaterThanMask.AsInt32(), maxIndex, curIndex); + } +#endif + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(ref float result, float current, int resultIndex, int curIndex) + { + float curMaxAbs = MathF.Abs(result); + float currentAbs = MathF.Abs(current); + + if (curMaxAbs == currentAbs) + { + if (IsNegative(result) && !IsNegative(current)) + { + result = current; + return curIndex; + } + } + else if (currentAbs > curMaxAbs) + { + result = current; + return curIndex; + } + + return resultIndex; + } + } + + /// Returns the index of MathF.Min(x, y) + private readonly struct IndexOfMinOperator : IIndexOfOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector128 result, Vector128 resultIndex) + { + Vector128 tmpResult =
Vector128.Shuffle(result, Vector128.Create(2, 3, 0, 1)); + Vector128 tmpIndex = Vector128.Shuffle(resultIndex, Vector128.Create(2, 3, 0, 1)); + + Invoke(ref result, tmpResult, ref resultIndex, tmpIndex); + + tmpResult = Vector128.Shuffle(result, Vector128.Create(1, 0, 3, 2)); + tmpIndex = Vector128.Shuffle(resultIndex, Vector128.Create(1, 0, 3, 2)); + + Invoke(ref result, tmpResult, ref resultIndex, tmpIndex); + return resultIndex.ToScalar(); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector128 result, Vector128 current, ref Vector128 resultIndex, Vector128 curIndex) + { + Vector128 lessThanMask = Vector128.LessThan(result, current); + + Vector128 equalMask = Vector128.Equals(result, current); + if (equalMask.AsInt32() != Vector128.Zero) + { + Vector128 negativeMask = IsNegative(current); + Vector128 lessThanIndexMask = Vector128.LessThan(resultIndex, curIndex); + + lessThanMask |= (~negativeMask & equalMask) | (IsNegative(result) & equalMask & lessThanIndexMask.AsSingle()); + } + + result = ElementWiseSelect(lessThanMask, result, current); + + resultIndex = ElementWiseSelect(lessThanMask.AsInt32(), resultIndex, curIndex); + } + + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector256 result, Vector256 resultIndex) + { + // Min the upper/lower halves of the Vector256 + Vector128 resultLower = result.GetLower(); + Vector128 indexLower = resultIndex.GetLower(); + + Invoke(ref resultLower, result.GetUpper(), ref indexLower, resultIndex.GetUpper()); + return Invoke(resultLower, indexLower); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector256 result, Vector256 current, ref Vector256 resultIndex, Vector256 curIndex) + { + Vector256 lessThanMask = Vector256.LessThan(result, current); + + Vector256 equalMask = Vector256.Equals(result, current); + if (equalMask.AsInt32() != Vector256.Zero) + { + Vector256 negativeMask = IsNegative(current); + Vector256 lessThanIndexMask = Vector256.LessThan(resultIndex, curIndex); + + lessThanMask |= (~negativeMask & equalMask) | (IsNegative(result) & equalMask & lessThanIndexMask.AsSingle()); + } + + result = ElementWiseSelect(lessThanMask, result, current); + + resultIndex = ElementWiseSelect(lessThanMask.AsInt32(), resultIndex, curIndex); + } + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector512 result, Vector512 resultIndex) + { + // Min the upper/lower halves of the Vector512 + Vector256 resultLower = result.GetLower(); + Vector256 indexLower = resultIndex.GetLower(); + + Invoke(ref resultLower, result.GetUpper(), ref indexLower, resultIndex.GetUpper()); + return Invoke(resultLower, indexLower); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector512 result, Vector512 current, ref Vector512 resultIndex, Vector512 curIndex) + { + Vector512 lessThanMask = Vector512.LessThan(result, current); + + Vector512 equalMask = Vector512.Equals(result, current); + if (equalMask.AsInt32() != Vector512.Zero) + { + Vector512 negativeMask = IsNegative(current); + Vector512 lessThanIndexMask = Vector512.LessThan(resultIndex, curIndex); + + lessThanMask |= (~negativeMask & equalMask) | (IsNegative(result) & equalMask & lessThanIndexMask.AsSingle()); + } + + result = ElementWiseSelect(lessThanMask, result, current); + + resultIndex = ElementWiseSelect(lessThanMask.AsInt32(), resultIndex, curIndex); + } +#endif + + 
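+ // Note: the equalMask adjustments above encode the tie rules: when two values compare equal, -0 is taken as smaller than +0, and among equal values the candidate with the lower index is preferred; the scalar fallback below applies the same rules.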
[MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(ref float result, float current, int resultIndex, int curIndex) + { + if (result == current) + { + if (IsPositive(result) && !IsPositive(current)) + { + result = current; + return curIndex; + } + } + else if (current < result) + { + result = current; + return curIndex; + } + + return resultIndex; + } + } + + private readonly struct IndexOfMinMagnitudeOperator : IIndexOfOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector128 result, Vector128 resultIndex) + { + Vector128 tmpResult = Vector128.Shuffle(result, Vector128.Create(2, 3, 0, 1)); + Vector128 tmpIndex = Vector128.Shuffle(resultIndex, Vector128.Create(2, 3, 0, 1)); + + Invoke(ref result, tmpResult, ref resultIndex, tmpIndex); + + tmpResult = Vector128.Shuffle(result, Vector128.Create(1, 0, 3, 2)); + tmpIndex = Vector128.Shuffle(resultIndex, Vector128.Create(1, 0, 3, 2)); + + Invoke(ref result, tmpResult, ref resultIndex, tmpIndex); + return resultIndex.ToScalar(); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector128 result, Vector128 current, ref Vector128 resultIndex, Vector128 curIndex) + { + Vector128 minMag = Vector128.Abs(result), currentMag = Vector128.Abs(current); + + Vector128 lessThanMask = Vector128.LessThan(minMag, currentMag); + + Vector128 equalMask = Vector128.Equals(result, current); + if (equalMask.AsInt32() != Vector128.Zero) + { + Vector128 negativeMask = IsNegative(current); + Vector128 lessThanIndexMask = Vector128.LessThan(resultIndex, curIndex); + + lessThanMask |= (~negativeMask & equalMask) | (IsNegative(result) & equalMask & lessThanIndexMask.AsSingle()); + } + + result = ElementWiseSelect(lessThanMask, result, current); + + resultIndex = ElementWiseSelect(lessThanMask.AsInt32(), resultIndex, curIndex); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector256 result, Vector256 resultIndex) + { + // Min the upper/lower halves of the Vector256 + Vector128 resultLower = result.GetLower(); + Vector128 indexLower = resultIndex.GetLower(); + + Invoke(ref resultLower, result.GetUpper(), ref indexLower, resultIndex.GetUpper()); + return Invoke(resultLower, indexLower); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector256 result, Vector256 current, ref Vector256 resultIndex, Vector256 curIndex) + { + Vector256 minMag = Vector256.Abs(result), currentMag = Vector256.Abs(current); + + Vector256 lessThanMask = Vector256.LessThan(minMag, currentMag); + + Vector256 equalMask = Vector256.Equals(result, current); + if (equalMask.AsInt32() != Vector256.Zero) + { + Vector256 negativeMask = IsNegative(current); + Vector256 lessThanIndexMask = Vector256.LessThan(resultIndex, curIndex); + + lessThanMask |= (~negativeMask & equalMask) | (IsNegative(result) & equalMask & lessThanIndexMask.AsSingle()); + } + + result = ElementWiseSelect(lessThanMask, result, current); + + resultIndex = ElementWiseSelect(lessThanMask.AsInt32(), resultIndex, curIndex); + } + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector512 result, Vector512 resultIndex) + { + // Min the upper/lower halves of the Vector512 + Vector256 resultLower = result.GetLower(); + Vector256 indexLower = resultIndex.GetLower(); + + Invoke(ref resultLower, result.GetUpper(), ref indexLower, resultIndex.GetUpper()); + return Invoke(resultLower, indexLower); + } + + 
[MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector512 result, Vector512 current, ref Vector512 resultIndex, Vector512 curIndex) + { + Vector512 minMag = Vector512.Abs(result), currentMag = Vector512.Abs(current); + + Vector512 lessThanMask = Vector512.LessThan(minMag, currentMag); + + Vector512 equalMask = Vector512.Equals(result, current); + if (equalMask.AsInt32() != Vector512.Zero) + { + Vector512 negativeMask = IsNegative(current); + Vector512 lessThanIndexMask = Vector512.LessThan(resultIndex, curIndex); + + lessThanMask |= (~negativeMask & equalMask) | (IsNegative(result) & equalMask & lessThanIndexMask.AsSingle()); + } + + result = ElementWiseSelect(lessThanMask, result, current); + + resultIndex = ElementWiseSelect(lessThanMask.AsInt32(), resultIndex, curIndex); + } +#endif + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(ref float result, float current, int resultIndex, int curIndex) + { + float curMinAbs = MathF.Abs(result); + float currentAbs = MathF.Abs(current); + if (curMinAbs == currentAbs) + { + if (IsPositive(result) && !IsPositive(current)) + { + result = current; + return curIndex; + } + } + else if (currentAbs < curMinAbs) + { + result = current; + return curIndex; + } + + return resultIndex; + } + } + + /// MathF.Max(x, y) + private readonly struct MaxPropagateNaNOperator : IBinaryOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static float Invoke(float x, float y) => MathF.Max(x, y); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector128 Invoke(Vector128 x, Vector128 y) + { + if (AdvSimd.IsSupported) + { + return AdvSimd.Max(x, y); + } + + return + Vector128.ConditionalSelect(Vector128.Equals(x, x), + Vector128.ConditionalSelect(Vector128.Equals(y, y), + Vector128.ConditionalSelect(Vector128.Equals(x, y), + Vector128.ConditionalSelect(IsNegative(x), y, x), + Vector128.Max(x, y)), + y), + x); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector256 Invoke(Vector256 x, Vector256 y) => + Vector256.ConditionalSelect(Vector256.Equals(x, x), + Vector256.ConditionalSelect(Vector256.Equals(y, y), + Vector256.ConditionalSelect(Vector256.Equals(x, y), + Vector256.ConditionalSelect(IsNegative(x), y, x), + Vector256.Max(x, y)), + y), + x); + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector512 Invoke(Vector512 x, Vector512 y) => + Vector512.ConditionalSelect(Vector512.Equals(x, x), + Vector512.ConditionalSelect(Vector512.Equals(y, y), + Vector512.ConditionalSelect(Vector512.Equals(x, y), + Vector512.ConditionalSelect(IsNegative(x), y, x), + Vector512.Max(x, y)), + y), + x); +#endif + } + + /// Operator to get x or y based on which has the larger MathF.Abs (but NaNs may not be propagated) + private readonly struct MaxMagnitudeOperator : IAggregationOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static float Invoke(float x, float y) + { + float xMag = MathF.Abs(x), yMag = MathF.Abs(y); + return + xMag == yMag ? + (IsNegative(x) ? y : x) : + (xMag > yMag ? 
x : y); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector128 Invoke(Vector128 x, Vector128 y) + { + Vector128 xMag = Vector128.Abs(x), yMag = Vector128.Abs(y); + return + Vector128.ConditionalSelect(Vector128.Equals(xMag, yMag), + Vector128.ConditionalSelect(IsNegative(x), y, x), + Vector128.ConditionalSelect(Vector128.GreaterThan(xMag, yMag), x, y)); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector256 Invoke(Vector256 x, Vector256 y) + { + Vector256 xMag = Vector256.Abs(x), yMag = Vector256.Abs(y); + return + Vector256.ConditionalSelect(Vector256.Equals(xMag, yMag), + Vector256.ConditionalSelect(IsNegative(x), y, x), + Vector256.ConditionalSelect(Vector256.GreaterThan(xMag, yMag), x, y)); + } + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector512 Invoke(Vector512 x, Vector512 y) + { + Vector512 xMag = Vector512.Abs(x), yMag = Vector512.Abs(y); + return + Vector512.ConditionalSelect(Vector512.Equals(xMag, yMag), + Vector512.ConditionalSelect(IsNegative(x), y, x), + Vector512.ConditionalSelect(Vector512.GreaterThan(xMag, yMag), x, y)); + } +#endif + + public static float Invoke(Vector128 x) => HorizontalAggregate(x); + public static float Invoke(Vector256 x) => HorizontalAggregate(x); +#if NET8_0_OR_GREATER + public static float Invoke(Vector512 x) => HorizontalAggregate(x); +#endif + } + + /// Operator to get x or y based on which has the larger MathF.Abs + private readonly struct MaxMagnitudePropagateNaNOperator : IBinaryOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static float Invoke(float x, float y) => MathF.MaxMagnitude(x, y); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector128 Invoke(Vector128 x, Vector128 y) + { + Vector128 xMag = Vector128.Abs(x), yMag = Vector128.Abs(y); + return + Vector128.ConditionalSelect(Vector128.Equals(x, x), + Vector128.ConditionalSelect(Vector128.Equals(y, y), + Vector128.ConditionalSelect(Vector128.Equals(yMag, xMag), + Vector128.ConditionalSelect(IsNegative(x), y, x), + Vector128.ConditionalSelect(Vector128.GreaterThan(yMag, xMag), y, x)), + y), + x); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector256 Invoke(Vector256 x, Vector256 y) + { + Vector256 xMag = Vector256.Abs(x), yMag = Vector256.Abs(y); + return + Vector256.ConditionalSelect(Vector256.Equals(x, x), + Vector256.ConditionalSelect(Vector256.Equals(y, y), + Vector256.ConditionalSelect(Vector256.Equals(xMag, yMag), + Vector256.ConditionalSelect(IsNegative(x), y, x), + Vector256.ConditionalSelect(Vector256.GreaterThan(xMag, yMag), x, y)), + y), + x); + } + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector512 Invoke(Vector512 x, Vector512 y) + { + Vector512 xMag = Vector512.Abs(x), yMag = Vector512.Abs(y); + return + Vector512.ConditionalSelect(Vector512.Equals(x, x), + Vector512.ConditionalSelect(Vector512.Equals(y, y), + Vector512.ConditionalSelect(Vector512.Equals(xMag, yMag), + Vector512.ConditionalSelect(IsNegative(x), y, x), + Vector512.ConditionalSelect(Vector512.GreaterThan(xMag, yMag), x, y)), + y), + x); + } +#endif + } + + /// MathF.Min(x, y) (but NaNs may not be propagated) + private readonly struct MinOperator : IAggregationOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static float Invoke(float x, float y) => + x == y ? + (IsNegative(y) ? y : x) : + (y < x ? 
y : x); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector128 Invoke(Vector128 x, Vector128 y) + { + if (AdvSimd.IsSupported) + { + return AdvSimd.Min(x, y); + } + + return + Vector128.ConditionalSelect(Vector128.Equals(x, y), + Vector128.ConditionalSelect(IsNegative(y), y, x), + Vector128.Min(x, y)); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector256 Invoke(Vector256 x, Vector256 y) => + Vector256.ConditionalSelect(Vector256.Equals(x, y), + Vector256.ConditionalSelect(IsNegative(y), y, x), + Vector256.Min(x, y)); + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector512 Invoke(Vector512 x, Vector512 y) => + Vector512.ConditionalSelect(Vector512.Equals(x, y), + Vector512.ConditionalSelect(IsNegative(y), y, x), + Vector512.Min(x, y)); +#endif + + public static float Invoke(Vector128 x) => HorizontalAggregate(x); + public static float Invoke(Vector256 x) => HorizontalAggregate(x); +#if NET8_0_OR_GREATER + public static float Invoke(Vector512 x) => HorizontalAggregate(x); +#endif + } + + /// MathF.Min(x, y) + private readonly struct MinPropagateNaNOperator : IBinaryOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static float Invoke(float x, float y) => MathF.Min(x, y); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector128 Invoke(Vector128 x, Vector128 y) + { + if (AdvSimd.IsSupported) + { + return AdvSimd.Min(x, y); + } + + return + Vector128.ConditionalSelect(Vector128.Equals(x, x), + Vector128.ConditionalSelect(Vector128.Equals(y, y), + Vector128.ConditionalSelect(Vector128.Equals(x, y), + Vector128.ConditionalSelect(IsNegative(x), x, y), + Vector128.Min(x, y)), + y), + x); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector256 Invoke(Vector256 x, Vector256 y) => + Vector256.ConditionalSelect(Vector256.Equals(x, x), + Vector256.ConditionalSelect(Vector256.Equals(y, y), + Vector256.ConditionalSelect(Vector256.Equals(x, y), + Vector256.ConditionalSelect(IsNegative(x), x, y), + Vector256.Min(x, y)), + y), + x); + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector512 Invoke(Vector512 x, Vector512 y) => + Vector512.ConditionalSelect(Vector512.Equals(x, x), + Vector512.ConditionalSelect(Vector512.Equals(y, y), + Vector512.ConditionalSelect(Vector512.Equals(x, y), + Vector512.ConditionalSelect(IsNegative(x), x, y), + Vector512.Min(x, y)), + y), + x); +#endif + } + + /// Operator to get x or y based on which has the smaller MathF.Abs (but NaNs may not be propagated) + private readonly struct MinMagnitudeOperator : IAggregationOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static float Invoke(float x, float y) + { + float xMag = MathF.Abs(x), yMag = MathF.Abs(y); + return xMag == yMag ? + (IsNegative(y) ? y : x) : + (yMag < xMag ? 
y : x); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector128 Invoke(Vector128 x, Vector128 y) + { + Vector128 xMag = Vector128.Abs(x), yMag = Vector128.Abs(y); + return + Vector128.ConditionalSelect(Vector128.Equals(yMag, xMag), + Vector128.ConditionalSelect(IsNegative(y), y, x), + Vector128.ConditionalSelect(Vector128.LessThan(yMag, xMag), y, x)); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector256 Invoke(Vector256 x, Vector256 y) + { + Vector256 xMag = Vector256.Abs(x), yMag = Vector256.Abs(y); + return + Vector256.ConditionalSelect(Vector256.Equals(yMag, xMag), + Vector256.ConditionalSelect(IsNegative(y), y, x), + Vector256.ConditionalSelect(Vector256.LessThan(yMag, xMag), y, x)); + } + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector512 Invoke(Vector512 x, Vector512 y) + { + Vector512 xMag = Vector512.Abs(x), yMag = Vector512.Abs(y); + return + Vector512.ConditionalSelect(Vector512.Equals(yMag, xMag), + Vector512.ConditionalSelect(IsNegative(y), y, x), + Vector512.ConditionalSelect(Vector512.LessThan(yMag, xMag), y, x)); + } +#endif + + public static float Invoke(Vector128 x) => HorizontalAggregate(x); + public static float Invoke(Vector256 x) => HorizontalAggregate(x); +#if NET8_0_OR_GREATER + public static float Invoke(Vector512 x) => HorizontalAggregate(x); +#endif + } + + /// Operator to get x or y based on which has the smaller MathF.Abs + private readonly struct MinMagnitudePropagateNaNOperator : IBinaryOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static float Invoke(float x, float y) => MathF.MinMagnitude(x, y); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector128 Invoke(Vector128 x, Vector128 y) + { + Vector128 xMag = Vector128.Abs(x), yMag = Vector128.Abs(y); + return + Vector128.ConditionalSelect(Vector128.Equals(x, x), + Vector128.ConditionalSelect(Vector128.Equals(y, y), + Vector128.ConditionalSelect(Vector128.Equals(yMag, xMag), + Vector128.ConditionalSelect(IsNegative(x), x, y), + Vector128.ConditionalSelect(Vector128.LessThan(xMag, yMag), x, y)), + y), + x); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector256 Invoke(Vector256 x, Vector256 y) + { + Vector256 xMag = Vector256.Abs(x), yMag = Vector256.Abs(y); + return + Vector256.ConditionalSelect(Vector256.Equals(x, x), + Vector256.ConditionalSelect(Vector256.Equals(y, y), + Vector256.ConditionalSelect(Vector256.Equals(yMag, xMag), + Vector256.ConditionalSelect(IsNegative(x), x, y), + Vector256.ConditionalSelect(Vector256.LessThan(xMag, yMag), x, y)), + y), + x); + } + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector512 Invoke(Vector512 x, Vector512 y) + { + Vector512 xMag = Vector512.Abs(x), yMag = Vector512.Abs(y); + return + Vector512.ConditionalSelect(Vector512.Equals(x, x), + Vector512.ConditionalSelect(Vector512.Equals(y, y), + Vector512.ConditionalSelect(Vector512.Equals(yMag, xMag), + Vector512.ConditionalSelect(IsNegative(x), x, y), + Vector512.ConditionalSelect(Vector512.LessThan(xMag, yMag), x, y)), + y), + x); + } +#endif + } + + /// -x + private readonly struct NegateOperator : IUnaryOperator + { + public static float Invoke(float x) => -x; + public static Vector128 Invoke(Vector128 x) => -x; + public static Vector256 Invoke(Vector256 x) => -x; +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x) => -x; +#endif + } + + /// (x + y) * z + 
private readonly struct AddMultiplyOperator : ITernaryOperator + { + public static float Invoke(float x, float y, float z) => (x + y) * z; + public static Vector128 Invoke(Vector128 x, Vector128 y, Vector128 z) => (x + y) * z; + public static Vector256 Invoke(Vector256 x, Vector256 y, Vector256 z) => (x + y) * z; +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x, Vector512 y, Vector512 z) => (x + y) * z; +#endif + } + + /// (x * y) + z + private readonly struct MultiplyAddOperator : ITernaryOperator + { + public static float Invoke(float x, float y, float z) => (x * y) + z; + public static Vector128 Invoke(Vector128 x, Vector128 y, Vector128 z) => (x * y) + z; + public static Vector256 Invoke(Vector256 x, Vector256 y, Vector256 z) => (x * y) + z; +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x, Vector512 y, Vector512 z) => (x * y) + z; +#endif + } + + /// x + private readonly struct IdentityOperator : IUnaryOperator + { + public static float Invoke(float x) => x; + public static Vector128 Invoke(Vector128 x) => x; + public static Vector256 Invoke(Vector256 x) => x; +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x) => x; +#endif + } + + /// x * x + private readonly struct SquaredOperator : IUnaryOperator + { + public static float Invoke(float x) => x * x; + public static Vector128 Invoke(Vector128 x) => x * x; + public static Vector256 Invoke(Vector256 x) => x * x; +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x) => x * x; +#endif + } + + /// MathF.Abs(x) + private readonly struct AbsoluteOperator : IUnaryOperator + { + public static float Invoke(float x) => MathF.Abs(x); + public static Vector128 Invoke(Vector128 x) => Vector128.Abs(x); + public static Vector256 Invoke(Vector256 x) => Vector256.Abs(x); +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x) => Vector512.Abs(x); +#endif + } + + /// MathF.Exp(x) + private readonly struct ExpOperator : IUnaryOperator + { + // This code is based on `vrs4_expf` from amd/aocl-libm-ose + // Copyright (C) 2019-2022 Advanced Micro Devices, Inc. All rights reserved. + // + // Licensed under the BSD 3-Clause "New" or "Revised" License + // See THIRD-PARTY-NOTICES.TXT for the full license text + + // Implementation Notes: + // 1. Argument Reduction: + // e^x = 2^(x/ln2) --- (1) + // + // Let x/ln(2) = z --- (2) + // + // Let z = n + r , where n is an integer --- (3) + // |r| <= 1/2 + // + // From (1), (2) and (3), + // e^x = 2^z + // = 2^(N+r) + // = (2^N)*(2^r) --- (4) + // + // 2. Polynomial Evaluation + // From (4), + // r = z - N + // 2^r = C1 + C2*r + C3*r^2 + C4*r^3 + C5*r^4 + C6*r^5 + + // 3.
Reconstruction + // Thus, + // e^x = (2^N) * (2^r) + + private const uint V_ARG_MAX = 0x42AE0000; + private const uint V_MASK = 0x7FFFFFFF; + + private const float V_EXPF_MIN = -103.97208f; + private const float V_EXPF_MAX = 88.72284f; + + private const double V_EXPF_HUGE = 6755399441055744; + private const double V_TBL_LN2 = 1.4426950408889634; + + private const double C1 = 1.0000000754895704; + private const double C2 = 0.6931472254087585; + private const double C3 = 0.2402210737432219; + private const double C4 = 0.05550297297702539; + private const double C5 = 0.009676036358193323; + private const double C6 = 0.001341000536524434; + + public static float Invoke(float x) => MathF.Exp(x); + + public static Vector128 Invoke(Vector128 x) + { + // Convert x to double precision + (Vector128 xl, Vector128 xu) = Vector128.Widen(x); + + // x * (1.0 / ln(2)) + Vector128 v_tbl_ln2 = Vector128.Create(V_TBL_LN2); + + Vector128 zl = xl * v_tbl_ln2; + Vector128 zu = xu * v_tbl_ln2; + + Vector128 v_expf_huge = Vector128.Create(V_EXPF_HUGE); + + Vector128 dnl = zl + v_expf_huge; + Vector128 dnu = zu + v_expf_huge; + + // n = int (z) + Vector128 nl = dnl.AsUInt64(); + Vector128 nu = dnu.AsUInt64(); + + // dn = double(n) + dnl -= v_expf_huge; + dnu -= v_expf_huge; + + // r = z - dn + Vector128 c1 = Vector128.Create(C1); + Vector128 c2 = Vector128.Create(C2); + Vector128 c3 = Vector128.Create(C3); + Vector128 c4 = Vector128.Create(C4); + Vector128 c5 = Vector128.Create(C5); + Vector128 c6 = Vector128.Create(C6); + + Vector128 rl = zl - dnl; + + Vector128 rl2 = rl * rl; + Vector128 rl4 = rl2 * rl2; + + Vector128 polyl = (c4 * rl + c3) * rl2 + + ((c6 * rl + c5) * rl4 + + (c2 * rl + c1)); + + + Vector128 ru = zu - dnu; + + Vector128 ru2 = ru * ru; + Vector128 ru4 = ru2 * ru2; + + Vector128 polyu = (c4 * ru + c3) * ru2 + + ((c6 * ru + c5) * ru4 + + (c2 * ru + c1)); + + // result = (float)[poly + (n << 52)] + Vector128 ret = Vector128.Narrow( + (polyl.AsUInt64() + Vector128.ShiftLeft(nl, 52)).AsDouble(), + (polyu.AsUInt64() + Vector128.ShiftLeft(nu, 52)).AsDouble() + ); + + // Check if -103 < |x| < 88 + if (Vector128.GreaterThanAny(x.AsUInt32() & Vector128.Create(V_MASK), Vector128.Create(V_ARG_MAX))) + { + // (x > V_EXPF_MAX) ? float.PositiveInfinity : x + Vector128 infinityMask = Vector128.GreaterThan(x, Vector128.Create(V_EXPF_MAX)); + + ret = Vector128.ConditionalSelect( + infinityMask, + Vector128.Create(float.PositiveInfinity), + ret + ); + + // (x < V_EXPF_MIN) ?
0 : x + ret = Vector128.AndNot(ret, Vector128.LessThan(x, Vector128.Create(V_EXPF_MIN))); + } + + return ret; + } + + public static Vector256 Invoke(Vector256 x) + { + // Convert x to double precision + (Vector256 xl, Vector256 xu) = Vector256.Widen(x); + + // x * (1.0 / ln(2)) + Vector256 v_tbl_ln2 = Vector256.Create(V_TBL_LN2); + + Vector256 zl = xl * v_tbl_ln2; + Vector256 zu = xu * v_tbl_ln2; + + Vector256 v_expf_huge = Vector256.Create(V_EXPF_HUGE); + + Vector256 dnl = zl + v_expf_huge; + Vector256 dnu = zu + v_expf_huge; + + // n = int (z) + Vector256 nl = dnl.AsUInt64(); + Vector256 nu = dnu.AsUInt64(); + + // dn = double(n) + dnl -= v_expf_huge; + dnu -= v_expf_huge; + + // r = z - dn + Vector256 c1 = Vector256.Create(C1); + Vector256 c2 = Vector256.Create(C2); + Vector256 c3 = Vector256.Create(C3); + Vector256 c4 = Vector256.Create(C4); + Vector256 c5 = Vector256.Create(C5); + Vector256 c6 = Vector256.Create(C6); + + Vector256 rl = zl - dnl; + + Vector256 rl2 = rl * rl; + Vector256 rl4 = rl2 * rl2; + + Vector256 polyl = (c4 * rl + c3) * rl2 + + ((c6 * rl + c5) * rl4 + + (c2 * rl + c1)); + + + Vector256 ru = zu - dnu; + + Vector256 ru2 = ru * ru; + Vector256 ru4 = ru2 * ru2; + + Vector256 polyu = (c4 * ru + c3) * ru2 + + ((c6 * ru + c5) * ru4 + + (c2 * ru + c1)); + + // result = (float)[poly + (n << 52)] + Vector256 ret = Vector256.Narrow( + (polyl.AsUInt64() + Vector256.ShiftLeft(nl, 52)).AsDouble(), + (polyu.AsUInt64() + Vector256.ShiftLeft(nu, 52)).AsDouble() + ); + + // Check if -103 < |x| < 88 + if (Vector256.GreaterThanAny(x.AsUInt32() & Vector256.Create(V_MASK), Vector256.Create(V_ARG_MAX))) + { + // (x > V_EXPF_MAX) ? float.PositiveInfinity : x + Vector256 infinityMask = Vector256.GreaterThan(x, Vector256.Create(V_EXPF_MAX)); + + ret = Vector256.ConditionalSelect( + infinityMask, + Vector256.Create(float.PositiveInfinity), + ret + ); + + // (x < V_EXPF_MIN) ?
0 : x + ret = Vector256.AndNot(ret, Vector256.LessThan(x, Vector256.Create(V_EXPF_MIN))); + } + + return ret; + } + +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x) + { + // Convert x to double precision + (Vector512 xl, Vector512 xu) = Vector512.Widen(x); + + // x * (64.0 / ln(2)) + Vector512 v_tbl_ln2 = Vector512.Create(V_TBL_LN2); + + Vector512 zl = xl * v_tbl_ln2; + Vector512 zu = xu * v_tbl_ln2; + + Vector512 v_expf_huge = Vector512.Create(V_EXPF_HUGE); + + Vector512 dnl = zl + v_expf_huge; + Vector512 dnu = zu + v_expf_huge; + + // n = int (z) + Vector512 nl = dnl.AsUInt64(); + Vector512 nu = dnu.AsUInt64(); + + // dn = double(n) + dnl -= v_expf_huge; + dnu -= v_expf_huge; + + // r = z - dn + Vector512 c1 = Vector512.Create(C1); + Vector512 c2 = Vector512.Create(C2); + Vector512 c3 = Vector512.Create(C3); + Vector512 c4 = Vector512.Create(C4); + Vector512 c5 = Vector512.Create(C5); + Vector512 c6 = Vector512.Create(C6); + + Vector512 rl = zl - dnl; + + Vector512 rl2 = rl * rl; + Vector512 rl4 = rl2 * rl2; + + Vector512 polyl = (c4 * rl + c3) * rl2 + + ((c6 * rl + c5) * rl4 + + (c2 * rl + c1)); + + + Vector512 ru = zu - dnu; + + Vector512 ru2 = ru * ru; + Vector512 ru4 = ru2 * ru2; + + Vector512 polyu = (c4 * ru + c3) * ru2 + + ((c6 * ru + c5) * ru4 + + (c2 * ru + c1)); + + // result = (float)[poly + (n << 52)] + Vector512 ret = Vector512.Narrow( + (polyl.AsUInt64() + Vector512.ShiftLeft(nl, 52)).AsDouble(), + (polyu.AsUInt64() + Vector512.ShiftLeft(nu, 52)).AsDouble() + ); + + // Check if -103 < |x| < 88 + if (Vector512.GreaterThanAny(x.AsUInt32() & Vector512.Create(V_MASK), Vector512.Create(V_ARG_MAX))) + { + // (x > V_EXPF_MAX) ? float.PositiveInfinity : x + Vector512 infinityMask = Vector512.GreaterThan(x, Vector512.Create(V_EXPF_MAX)); + + ret = Vector512.ConditionalSelect( + infinityMask, + Vector512.Create(float.PositiveInfinity), + ret + ); + + // (x < V_EXPF_MIN) ? 0 : x + ret = Vector512.AndNot(ret, Vector512.LessThan(x, Vector512.Create(V_EXPF_MIN))); + } + + return ret; + } +#endif + } + + /// MathF.Cosh(x) + private readonly struct CoshOperator : IUnaryOperator + { + // This code is based on `vrs4_coshf` from amd/aocl-libm-ose + // Copyright (C) 2008-2022 Advanced Micro Devices, Inc. All rights reserved. + // + // Licensed under the BSD 3-Clause "New" or "Revised" License + // See THIRD-PARTY-NOTICES.TXT for the full license text + + // Spec: + // coshf(|x| > 89.415985107421875) = Infinity + // coshf(Infinity) = infinity + // coshf(-Infinity) = infinity + // + // cosh(x) = (exp(x) + exp(-x))/2 + // cosh(-x) = +cosh(x) + // + // checks for special cases + // if ( asint(x) > infinity) return x with overflow exception and + // return x. + // if x is NaN then raise invalid FP operation exception and return x. 
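+            // Derivation of the form used below: let z = exp(x - log(v)); then exp(x) = v * z
+            // and exp(-x) = 1 / (v * z), so
+            //   cosh(x) = (exp(x) + exp(-x)) / 2 = (v / 2) * (z + (1 / v^2) / z).
+            // The constants LOGV = log(v), HALFV = v / 2, and INVV2 = 1 / v^2 are all consistent
+            // with v ~= 2.0000277, so only one ExpOperator evaluation is needed per input.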
+ // + // coshf = v/2 * exp(x - log(v)) where v = 0x1.0000e8p-1 + + private const float LOGV = 0.693161f; + private const float HALFV = 1.0000138f; + private const float INVV2 = 0.24999309f; + + public static float Invoke(float x) => MathF.Cosh(x); + + public static Vector128 Invoke(Vector128 x) + { + Vector128 y = Vector128.Abs(x); + Vector128 z = ExpOperator.Invoke(y - Vector128.Create(LOGV)); + return Vector128.Create(HALFV) * (z + (Vector128.Create(INVV2) / z)); + } + + public static Vector256 Invoke(Vector256 x) + { + Vector256 y = Vector256.Abs(x); + Vector256 z = ExpOperator.Invoke(y - Vector256.Create(LOGV)); + return Vector256.Create(HALFV) * (z + (Vector256.Create(INVV2) / z)); } - if (Vector128.IsHardwareAccelerated) +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x) { - oneVectorFromEnd = x.Length - Vector128.Count; - if (i <= oneVectorFromEnd) - { - Vector128 zVec = Vector128.Create(z); + Vector512 y = Vector512.Abs(x); + Vector512 z = ExpOperator.Invoke(y - Vector512.Create(LOGV)); + return Vector512.Create(HALFV) * (z + (Vector512.Create(INVV2) / z)); + } +#endif + } - // Loop handling one vector at a time. - do - { - TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i), - Vector128.LoadUnsafe(ref yRef, (uint)i), - zVec).StoreUnsafe(ref dRef, (uint)i); + /// MathF.Sinh(x) + private readonly struct SinhOperator : IUnaryOperator + { + // Same as cosh, but with `z -` rather than `z +`, and with the sign + // flipped on the result based on the sign of the input. - i += Vector128.Count; - } - while (i <= oneVectorFromEnd); + private const uint SIGN_MASK = 0x7FFFFFFF; + private const float LOGV = 0.693161f; + private const float HALFV = 1.0000138f; + private const float INVV2 = 0.24999309f; - // Handle any remaining elements with a final vector. 
- if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector128.Count); - TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex), - Vector128.LoadUnsafe(ref yRef, lastVectorIndex), - zVec).StoreUnsafe(ref dRef, lastVectorIndex); - } + public static float Invoke(float x) => MathF.Sinh(x); - return; - } + public static Vector128 Invoke(Vector128 x) + { + Vector128 y = Vector128.Abs(x); + Vector128 z = ExpOperator.Invoke(y - Vector128.Create(LOGV)); + Vector128 result = Vector128.Create(HALFV) * (z - (Vector128.Create(INVV2) / z)); + Vector128 sign = x.AsUInt32() & Vector128.Create(~SIGN_MASK); + return (sign ^ result.AsUInt32()).AsSingle(); } - while (i < x.Length) + public static Vector256 Invoke(Vector256 x) { - Unsafe.Add(ref dRef, i) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, i), - Unsafe.Add(ref yRef, i), - z); + Vector256 y = Vector256.Abs(x); + Vector256 z = ExpOperator.Invoke(y - Vector256.Create(LOGV)); + Vector256 result = Vector256.Create(HALFV) * (z - (Vector256.Create(INVV2) / z)); + Vector256 sign = x.AsUInt32() & Vector256.Create(~SIGN_MASK); + return (sign ^ result.AsUInt32()).AsSingle(); + } - i++; +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x) + { + Vector512 y = Vector512.Abs(x); + Vector512 z = ExpOperator.Invoke(y - Vector512.Create(LOGV)); + Vector512 result = Vector512.Create(HALFV) * (z - (Vector512.Create(INVV2) / z)); + Vector512 sign = x.AsUInt32() & Vector512.Create(~SIGN_MASK); + return (sign ^ result.AsUInt32()).AsSingle(); } +#endif } - private static unsafe void InvokeSpanScalarSpanIntoSpan( - ReadOnlySpan x, float y, ReadOnlySpan z, Span destination) - where TTernaryOperator : ITernaryOperator + /// MathF.Tanh(x) + private readonly struct TanhOperator : IUnaryOperator { - if (x.Length != z.Length) + // This code is based on `vrs4_tanhf` from amd/aocl-libm-ose + // Copyright (C) 2008-2022 Advanced Micro Devices, Inc. All rights reserved. 
+ // + // Licensed under the BSD 3-Clause "New" or "Revised" License + // See THIRD-PARTY-NOTICES.TXT for the full license text + + // To compute vrs4_tanhf(v_f32x4_t x) + // Let y = |x| + // If 0 <= y < 0x1.154246p3 + // Let z = e^(-2.0 * y) - 1 -(1) + // + // Using (1), tanhf(y) can be calculated as, + // tanhf(y) = -z / (z + 2.0) + // + // For other cases, call scalar tanhf() + // + // If x < 0, then we use the identity + // tanhf(-x) = -tanhf(x) + + private const uint SIGN_MASK = 0x7FFFFFFF; + + public static float Invoke(float x) => MathF.Tanh(x); + + public static Vector128 Invoke(Vector128 x) { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + Vector128 y = Vector128.Abs(x); + Vector128 z = ExpOperator.Invoke(Vector128.Create(-2f) * y) - Vector128.Create(1f); + Vector128 sign = x.AsUInt32() & Vector128.Create(~SIGN_MASK); + return (sign ^ (-z / (z + Vector128.Create(2f))).AsUInt32()).AsSingle(); } - if (x.Length > destination.Length) + public static Vector256 Invoke(Vector256 x) { - ThrowHelper.ThrowArgument_DestinationTooShort(); + Vector256 y = Vector256.Abs(x); + Vector256 z = ExpOperator.Invoke(Vector256.Create(-2f) * y) - Vector256.Create(1f); + Vector256 sign = x.AsUInt32() & Vector256.Create(~SIGN_MASK); + return (sign ^ (-z / (z + Vector256.Create(2f))).AsUInt32()).AsSingle(); } - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float zRef = ref MemoryMarshal.GetReference(z); - ref float dRef = ref MemoryMarshal.GetReference(destination); - int i = 0, oneVectorFromEnd; - #if NET8_0_OR_GREATER - if (Vector512.IsHardwareAccelerated) + public static Vector512 Invoke(Vector512 x) { - oneVectorFromEnd = x.Length - Vector512.Count; - if (i <= oneVectorFromEnd) - { - Vector512 yVec = Vector512.Create(y); + Vector512 y = Vector512.Abs(x); + Vector512 z = ExpOperator.Invoke(Vector512.Create(-2f) * y) - Vector512.Create(1f); + Vector512 sign = x.AsUInt32() & Vector512.Create(~SIGN_MASK); + return (sign ^ (-z / (z + Vector512.Create(2f))).AsUInt32()).AsSingle(); + } +#endif + } - // Loop handling one vector at a time. - do - { - TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i), - yVec, - Vector512.LoadUnsafe(ref zRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + /// MathF.Log(x) + private readonly struct LogOperator : IUnaryOperator + { + // This code is based on `vrs4_logf` from amd/aocl-libm-ose + // Copyright (C) 2018-2019 Advanced Micro Devices, Inc. All rights reserved. + // + // Licensed under the BSD 3-Clause "New" or "Revised" License + // See THIRD-PARTY-NOTICES.TXT for the full license text + + // Spec: + // logf(x) + // = logf(x) if x ∈ F and x > 0 + // = x if x = qNaN + // = 0 if x = 1 + // = -inf if x = (-0, 0} + // = NaN otherwise + // + // Assumptions/Expectations + // - ULP is derived to be << 4 (always) + // - Some FPU Exceptions may not be available + // - Performance is at least 3x + // + // Implementation Notes: + // 1. Range Reduction: + // x = 2^n*(1+f) .... (1) + // where n is exponent and is an integer + // (1+f) is mantissa ∈ [1,2). i.e., 1 ≤ 1+f < 2 .... (2) + // + // From (1), taking log on both sides + // log(x) = log(2^n * (1+f)) + // = log(2^n) + log(1+f) + // = n*log(2) + log(1+f) .... (3) + // + // let z = 1 + f + // log(z) = log(k) + log(z) - log(k) + // log(z) = log(kz) - log(k) + // + // From (2), range of z is [1, 2) + // by simply dividing range by 'k', z is in [1/k, 2/k) .... 
(4) + // Best choice of k is the one which gives equal and opposite values + // at extrema +- -+ + // 1 | 2 | + // --- - 1 = - |--- - 1 | + // k | k | .... (5) + // +- -+ + // + // Solving for k, k = 3/2, + // From (4), using 'k' value, range is therefore [-0.3333, 0.3333] + // + // 2. Polynomial Approximation: + // More information refer to tools/sollya/vrs4_logf.sollya + // + // 7th Deg - Error abs: 0x1.04c4ac98p-22 rel: 0x1.2216e6f8p-19 + // 6th Deg - Error abs: 0x1.179e97d8p-19 rel: 0x1.db676c1p-17 + + private const uint V_MIN = 0x00800000; + private const uint V_MAX = 0x7F800000; + private const uint V_MASK = 0x007FFFFF; + private const uint V_OFF = 0x3F2AAAAB; + + private const float V_LN2 = 0.6931472f; + + private const float C0 = 0.0f; + private const float C1 = 1.0f; + private const float C2 = -0.5000001f; + private const float C3 = 0.33332965f; + private const float C4 = -0.24999046f; + private const float C5 = 0.20018855f; + private const float C6 = -0.16700386f; + private const float C7 = 0.13902695f; + private const float C8 = -0.1197452f; + private const float C9 = 0.14401625f; + private const float C10 = -0.13657966f; + + public static float Invoke(float x) => MathF.Log(x); - i += Vector512.Count; - } - while (i <= oneVectorFromEnd); + public static Vector128 Invoke(Vector128 x) + { + Vector128 specialResult = x; - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector512.Count); - TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex), - yVec, - Vector512.LoadUnsafe(ref zRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); - } + // x is subnormal or infinity or NaN + Vector128 specialMask = Vector128.GreaterThanOrEqual(x.AsUInt32() - Vector128.Create(V_MIN), Vector128.Create(V_MAX - V_MIN)); - return; + if (specialMask != Vector128.Zero) + { + // float.IsZero(x) ? float.NegativeInfinity : x + Vector128 zeroMask = Vector128.Equals(x, Vector128.Zero); + + specialResult = Vector128.ConditionalSelect( + zeroMask, + Vector128.Create(float.NegativeInfinity), + specialResult + ); + + // (x < 0) ? float.NaN : x + Vector128 lessThanZeroMask = Vector128.LessThan(x, Vector128.Zero); + + specialResult = Vector128.ConditionalSelect( + lessThanZeroMask, + Vector128.Create(float.NaN), + specialResult + ); + + // float.IsZero(x) | (x < 0) | float.IsNaN(x) | float.IsPositiveInfinity(x) + Vector128 temp = zeroMask + | lessThanZeroMask + | ~Vector128.Equals(x, x) + | Vector128.Equals(x, Vector128.Create(float.PositiveInfinity)); + + // subnormal + Vector128 subnormalMask = Vector128.AndNot(specialMask.AsSingle(), temp); + + x = Vector128.ConditionalSelect( + subnormalMask, + ((x * 8388608.0f).AsUInt32() - Vector128.Create(23u << 23)).AsSingle(), + x + ); + + specialMask = temp.AsUInt32(); } - } -#endif - if (Vector256.IsHardwareAccelerated) - { - oneVectorFromEnd = x.Length - Vector256.Count; - if (i <= oneVectorFromEnd) - { - Vector256 yVec = Vector256.Create(y); + Vector128 vx = x.AsUInt32() - Vector128.Create(V_OFF); + Vector128 n = Vector128.ConvertToSingle(Vector128.ShiftRightArithmetic(vx.AsInt32(), 23)); - // Loop handling one vector at a time. 
- do - { - TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i), - yVec, - Vector256.LoadUnsafe(ref zRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + vx = (vx & Vector128.Create(V_MASK)) + Vector128.Create(V_OFF); - i += Vector256.Count; - } - while (i <= oneVectorFromEnd); + Vector128 r = vx.AsSingle() - Vector128.Create(1.0f); - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector256.Count); - TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex), - yVec, - Vector256.LoadUnsafe(ref zRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); - } + Vector128 r2 = r * r; + Vector128 r4 = r2 * r2; + Vector128 r8 = r4 * r4; - return; - } + Vector128 q = (Vector128.Create(C10) * r2 + (Vector128.Create(C9) * r + Vector128.Create(C8))) + * r8 + (((Vector128.Create(C7) * r + Vector128.Create(C6)) + * r2 + (Vector128.Create(C5) * r + Vector128.Create(C4))) + * r4 + ((Vector128.Create(C3) * r + Vector128.Create(C2)) + * r2 + (Vector128.Create(C1) * r + Vector128.Create(C0)))); + + return Vector128.ConditionalSelect( + specialMask.AsSingle(), + specialResult, + n * Vector128.Create(V_LN2) + q + ); } - if (Vector128.IsHardwareAccelerated) + public static Vector256 Invoke(Vector256 x) { - oneVectorFromEnd = x.Length - Vector128.Count; - if (i <= oneVectorFromEnd) + Vector256 specialResult = x; + + // x is subnormal or infinity or NaN + Vector256 specialMask = Vector256.GreaterThanOrEqual(x.AsUInt32() - Vector256.Create(V_MIN), Vector256.Create(V_MAX - V_MIN)); + + if (specialMask != Vector256.Zero) { - Vector128 yVec = Vector128.Create(y); + // float.IsZero(x) ? float.NegativeInfinity : x + Vector256 zeroMask = Vector256.Equals(x, Vector256.Zero); + + specialResult = Vector256.ConditionalSelect( + zeroMask, + Vector256.Create(float.NegativeInfinity), + specialResult + ); + + // (x < 0) ? float.NaN : x + Vector256 lessThanZeroMask = Vector256.LessThan(x, Vector256.Zero); + + specialResult = Vector256.ConditionalSelect( + lessThanZeroMask, + Vector256.Create(float.NaN), + specialResult + ); + + // float.IsZero(x) | (x < 0) | float.IsNaN(x) | float.IsPositiveInfinity(x) + Vector256 temp = zeroMask + | lessThanZeroMask + | ~Vector256.Equals(x, x) + | Vector256.Equals(x, Vector256.Create(float.PositiveInfinity)); + + // subnormal + Vector256 subnormalMask = Vector256.AndNot(specialMask.AsSingle(), temp); + + x = Vector256.ConditionalSelect( + subnormalMask, + ((x * 8388608.0f).AsUInt32() - Vector256.Create(23u << 23)).AsSingle(), + x + ); + + specialMask = temp.AsUInt32(); + } - // Loop handling one vector at a time. - do - { - TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i), - yVec, - Vector128.LoadUnsafe(ref zRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + Vector256 vx = x.AsUInt32() - Vector256.Create(V_OFF); + Vector256 n = Vector256.ConvertToSingle(Vector256.ShiftRightArithmetic(vx.AsInt32(), 23)); - i += Vector128.Count; - } - while (i <= oneVectorFromEnd); + vx = (vx & Vector256.Create(V_MASK)) + Vector256.Create(V_OFF); - // Handle any remaining elements with a final vector. 
- if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector128.Count); - TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex), - yVec, - Vector128.LoadUnsafe(ref zRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); - } + Vector256 r = vx.AsSingle() - Vector256.Create(1.0f); - return; - } - } + Vector256 r2 = r * r; + Vector256 r4 = r2 * r2; + Vector256 r8 = r4 * r4; - while (i < x.Length) - { - Unsafe.Add(ref dRef, i) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, i), - y, - Unsafe.Add(ref zRef, i)); + Vector256 q = (Vector256.Create(C10) * r2 + (Vector256.Create(C9) * r + Vector256.Create(C8))) + * r8 + (((Vector256.Create(C7) * r + Vector256.Create(C6)) + * r2 + (Vector256.Create(C5) * r + Vector256.Create(C4))) + * r4 + ((Vector256.Create(C3) * r + Vector256.Create(C2)) + * r2 + (Vector256.Create(C1) * r + Vector256.Create(C0)))); - i++; + return Vector256.ConditionalSelect( + specialMask.AsSingle(), + specialResult, + n * Vector256.Create(V_LN2) + q + ); } - } - private readonly struct AddOperator : IBinaryOperator - { - public static float Invoke(float x, float y) => x + y; - public static Vector128 Invoke(Vector128 x, Vector128 y) => x + y; - public static Vector256 Invoke(Vector256 x, Vector256 y) => x + y; #if NET8_0_OR_GREATER - public static Vector512 Invoke(Vector512 x, Vector512 y) => x + y; -#endif + public static Vector512 Invoke(Vector512 x) + { + Vector512 specialResult = x; - public static float Invoke(Vector128 x) => Vector128.Sum(x); - public static float Invoke(Vector256 x) => Vector256.Sum(x); -#if NET8_0_OR_GREATER - public static float Invoke(Vector512 x) => Vector512.Sum(x); -#endif - } + // x is subnormal or infinity or NaN + Vector512 specialMask = Vector512.GreaterThanOrEqual(x.AsUInt32() - Vector512.Create(V_MIN), Vector512.Create(V_MAX - V_MIN)); - private readonly struct SubtractOperator : IBinaryOperator - { - public static float Invoke(float x, float y) => x - y; - public static Vector128 Invoke(Vector128 x, Vector128 y) => x - y; - public static Vector256 Invoke(Vector256 x, Vector256 y) => x - y; -#if NET8_0_OR_GREATER - public static Vector512 Invoke(Vector512 x, Vector512 y) => x - y; -#endif + if (specialMask != Vector512.Zero) + { + // float.IsZero(x) ? float.NegativeInfinity : x + Vector512 zeroMask = Vector512.Equals(x, Vector512.Zero); + + specialResult = Vector512.ConditionalSelect( + zeroMask, + Vector512.Create(float.NegativeInfinity), + specialResult + ); + + // (x < 0) ? 
float.NaN : x + Vector512 lessThanZeroMask = Vector512.LessThan(x, Vector512.Zero); + + specialResult = Vector512.ConditionalSelect( + lessThanZeroMask, + Vector512.Create(float.NaN), + specialResult + ); + + // float.IsZero(x) | (x < 0) | float.IsNaN(x) | float.IsPositiveInfinity(x) + Vector512 temp = zeroMask + | lessThanZeroMask + | ~Vector512.Equals(x, x) + | Vector512.Equals(x, Vector512.Create(float.PositiveInfinity)); + + // subnormal + Vector512 subnormalMask = Vector512.AndNot(specialMask.AsSingle(), temp); + + x = Vector512.ConditionalSelect( + subnormalMask, + ((x * 8388608.0f).AsUInt32() - Vector512.Create(23u << 23)).AsSingle(), + x + ); + + specialMask = temp.AsUInt32(); + } - public static float Invoke(Vector128 x) => throw new NotSupportedException(); - public static float Invoke(Vector256 x) => throw new NotSupportedException(); -#if NET8_0_OR_GREATER - public static float Invoke(Vector512 x) => throw new NotSupportedException(); -#endif - } + Vector512 vx = x.AsUInt32() - Vector512.Create(V_OFF); + Vector512 n = Vector512.ConvertToSingle(Vector512.ShiftRightArithmetic(vx.AsInt32(), 23)); - private readonly struct SubtractSquaredOperator : IBinaryOperator - { - public static float Invoke(float x, float y) - { - float tmp = x - y; - return tmp * tmp; - } + vx = (vx & Vector512.Create(V_MASK)) + Vector512.Create(V_OFF); - public static Vector128 Invoke(Vector128 x, Vector128 y) - { - Vector128 tmp = x - y; - return tmp * tmp; - } + Vector512 r = vx.AsSingle() - Vector512.Create(1.0f); - public static Vector256 Invoke(Vector256 x, Vector256 y) - { - Vector256 tmp = x - y; - return tmp * tmp; - } + Vector512 r2 = r * r; + Vector512 r4 = r2 * r2; + Vector512 r8 = r4 * r4; -#if NET8_0_OR_GREATER - public static Vector512 Invoke(Vector512 x, Vector512 y) - { - Vector512 tmp = x - y; - return tmp * tmp; - } -#endif + Vector512 q = (Vector512.Create(C10) * r2 + (Vector512.Create(C9) * r + Vector512.Create(C8))) + * r8 + (((Vector512.Create(C7) * r + Vector512.Create(C6)) + * r2 + (Vector512.Create(C5) * r + Vector512.Create(C4))) + * r4 + ((Vector512.Create(C3) * r + Vector512.Create(C2)) + * r2 + (Vector512.Create(C1) * r + Vector512.Create(C0)))); - public static float Invoke(Vector128 x) => throw new NotSupportedException(); - public static float Invoke(Vector256 x) => throw new NotSupportedException(); -#if NET8_0_OR_GREATER - public static float Invoke(Vector512 x) => throw new NotSupportedException(); + return Vector512.ConditionalSelect( + specialMask.AsSingle(), + specialResult, + n * Vector512.Create(V_LN2) + q + ); + } #endif } - private readonly struct MultiplyOperator : IBinaryOperator + /// MathF.Log2(x) + private readonly struct Log2Operator : IUnaryOperator { - public static float Invoke(float x, float y) => x * y; - public static Vector128 Invoke(Vector128 x, Vector128 y) => x * y; - public static Vector256 Invoke(Vector256 x, Vector256 y) => x * y; -#if NET8_0_OR_GREATER - public static Vector512 Invoke(Vector512 x, Vector512 y) => x * y; -#endif + // This code is based on `vrs4_log2f` from amd/aocl-libm-ose + // Copyright (C) 2021-2022 Advanced Micro Devices, Inc. All rights reserved. 
+ // + // Licensed under the BSD 3-Clause "New" or "Revised" License + // See THIRD-PARTY-NOTICES.TXT for the full license text + + // Spec: + // log2f(x) + // = log2f(x) if x ∈ F and x > 0 + // = x if x = qNaN + // = 0 if x = 1 + // = -inf if x = (-0, 0} + // = NaN otherwise + // + // Assumptions/Expectations + // - Maximum ULP is observed to be at 4 + // - Some FPU Exceptions may not be available + // - Performance is at least 3x + // + // Implementation Notes: + // 1. Range Reduction: + // x = 2^n*(1+f) .... (1) + // where n is exponent and is an integer + // (1+f) is mantissa ∈ [1,2). i.e., 1 ≤ 1+f < 2 .... (2) + // + // From (1), taking log on both sides + // log2(x) = log2(2^n * (1+f)) + // = n + log2(1+f) .... (3) + // + // let z = 1 + f + // log2(z) = log2(k) + log2(z) - log2(k) + // log2(z) = log2(kz) - log2(k) + // + // From (2), range of z is [1, 2) + // by simply dividing range by 'k', z is in [1/k, 2/k) .... (4) + // Best choice of k is the one which gives equal and opposite values + // at extrema +- -+ + // 1 | 2 | + // --- - 1 = - |--- - 1 | + // k | k | .... (5) + // +- -+ + // + // Solving for k, k = 3/2, + // From (4), using 'k' value, range is therefore [-0.3333, 0.3333] + // + // 2. Polynomial Approximation: + // More information refer to tools/sollya/vrs4_logf.sollya + // + // 7th Deg - Error abs: 0x1.04c4ac98p-22 rel: 0x1.2216e6f8p-19 + + private const uint V_MIN = 0x00800000; + private const uint V_MAX = 0x7F800000; + private const uint V_MASK = 0x007FFFFF; + private const uint V_OFF = 0x3F2AAAAB; + + private const float C0 = 0.0f; + private const float C1 = 1.4426951f; + private const float C2 = -0.72134554f; + private const float C3 = 0.48089063f; + private const float C4 = -0.36084408f; + private const float C5 = 0.2888971f; + private const float C6 = -0.23594281f; + private const float C7 = 0.19948183f; + private const float C8 = -0.22616665f; + private const float C9 = 0.21228963f; + + public static float Invoke(float x) => MathF.Log2(x); - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static float Invoke(Vector128 x) + public static Vector128 Invoke(Vector128 x) { - float f = x[0]; - for (int i = 1; i < Vector128.Count; i++) + Vector128 specialResult = x; + + // x is subnormal or infinity or NaN + Vector128 specialMask = Vector128.GreaterThanOrEqual(x.AsUInt32() - Vector128.Create(V_MIN), Vector128.Create(V_MAX - V_MIN)); + + if (specialMask != Vector128.Zero) { - f *= x[i]; + // float.IsZero(x) ? float.NegativeInfinity : x + Vector128 zeroMask = Vector128.Equals(x, Vector128.Zero); + + specialResult = Vector128.ConditionalSelect( + zeroMask, + Vector128.Create(float.NegativeInfinity), + specialResult + ); + + // (x < 0) ? 
float.NaN : x + Vector128 lessThanZeroMask = Vector128.LessThan(x, Vector128.Zero); + + specialResult = Vector128.ConditionalSelect( + lessThanZeroMask, + Vector128.Create(float.NaN), + specialResult + ); + + // float.IsZero(x) | (x < 0) | float.IsNaN(x) | float.IsPositiveInfinity(x) + Vector128 temp = zeroMask + | lessThanZeroMask + | ~Vector128.Equals(x, x) + | Vector128.Equals(x, Vector128.Create(float.PositiveInfinity)); + + // subnormal + Vector128 subnormalMask = Vector128.AndNot(specialMask.AsSingle(), temp); + + x = Vector128.ConditionalSelect( + subnormalMask, + ((x * 8388608.0f).AsUInt32() - Vector128.Create(23u << 23)).AsSingle(), + x + ); + + specialMask = temp.AsUInt32(); } - return f; + + Vector128 vx = x.AsUInt32() - Vector128.Create(V_OFF); + Vector128 n = Vector128.ConvertToSingle(Vector128.ShiftRightArithmetic(vx.AsInt32(), 23)); + + vx = (vx & Vector128.Create(V_MASK)) + Vector128.Create(V_OFF); + + Vector128 r = vx.AsSingle() - Vector128.Create(1.0f); + + Vector128 r2 = r * r; + Vector128 r4 = r2 * r2; + Vector128 r8 = r4 * r4; + + Vector128 poly = (Vector128.Create(C9) * r + Vector128.Create(C8)) * r8 + + (((Vector128.Create(C7) * r + Vector128.Create(C6)) * r2 + + (Vector128.Create(C5) * r + Vector128.Create(C4))) * r4 + + ((Vector128.Create(C3) * r + Vector128.Create(C2)) * r2 + + (Vector128.Create(C1) * r + Vector128.Create(C0)))); + + return Vector128.ConditionalSelect( + specialMask.AsSingle(), + specialResult, + n + poly + ); } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static float Invoke(Vector256 x) + public static Vector256 Invoke(Vector256 x) { - float f = x[0]; - for (int i = 1; i < Vector256.Count; i++) + Vector256 specialResult = x; + + // x is subnormal or infinity or NaN + Vector256 specialMask = Vector256.GreaterThanOrEqual(x.AsUInt32() - Vector256.Create(V_MIN), Vector256.Create(V_MAX - V_MIN)); + + if (specialMask != Vector256.Zero) { - f *= x[i]; + // float.IsZero(x) ? float.NegativeInfinity : x + Vector256 zeroMask = Vector256.Equals(x, Vector256.Zero); + + specialResult = Vector256.ConditionalSelect( + zeroMask, + Vector256.Create(float.NegativeInfinity), + specialResult + ); + + // (x < 0) ? 
float.NaN : x + Vector256 lessThanZeroMask = Vector256.LessThan(x, Vector256.Zero); + + specialResult = Vector256.ConditionalSelect( + lessThanZeroMask, + Vector256.Create(float.NaN), + specialResult + ); + + // float.IsZero(x) | (x < 0) | float.IsNaN(x) | float.IsPositiveInfinity(x) + Vector256 temp = zeroMask + | lessThanZeroMask + | ~Vector256.Equals(x, x) + | Vector256.Equals(x, Vector256.Create(float.PositiveInfinity)); + + // subnormal + Vector256 subnormalMask = Vector256.AndNot(specialMask.AsSingle(), temp); + + x = Vector256.ConditionalSelect( + subnormalMask, + ((x * 8388608.0f).AsUInt32() - Vector256.Create(23u << 23)).AsSingle(), + x + ); + + specialMask = temp.AsUInt32(); } - return f; + + Vector256 vx = x.AsUInt32() - Vector256.Create(V_OFF); + Vector256 n = Vector256.ConvertToSingle(Vector256.ShiftRightArithmetic(vx.AsInt32(), 23)); + + vx = (vx & Vector256.Create(V_MASK)) + Vector256.Create(V_OFF); + + Vector256 r = vx.AsSingle() - Vector256.Create(1.0f); + + Vector256 r2 = r * r; + Vector256 r4 = r2 * r2; + Vector256 r8 = r4 * r4; + + Vector256 poly = (Vector256.Create(C9) * r + Vector256.Create(C8)) * r8 + + (((Vector256.Create(C7) * r + Vector256.Create(C6)) * r2 + + (Vector256.Create(C5) * r + Vector256.Create(C4))) * r4 + + ((Vector256.Create(C3) * r + Vector256.Create(C2)) * r2 + + (Vector256.Create(C1) * r + Vector256.Create(C0)))); + + return Vector256.ConditionalSelect( + specialMask.AsSingle(), + specialResult, + n + poly + ); } #if NET8_0_OR_GREATER - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static float Invoke(Vector512 x) + public static Vector512 Invoke(Vector512 x) { - float f = x[0]; - for (int i = 1; i < Vector512.Count; i++) + Vector512 specialResult = x; + + // x is subnormal or infinity or NaN + Vector512 specialMask = Vector512.GreaterThanOrEqual(x.AsUInt32() - Vector512.Create(V_MIN), Vector512.Create(V_MAX - V_MIN)); + + if (specialMask != Vector512.Zero) { - f *= x[i]; + // float.IsZero(x) ? float.NegativeInfinity : x + Vector512 zeroMask = Vector512.Equals(x, Vector512.Zero); + + specialResult = Vector512.ConditionalSelect( + zeroMask, + Vector512.Create(float.NegativeInfinity), + specialResult + ); + + // (x < 0) ? 
float.NaN : x + Vector512 lessThanZeroMask = Vector512.LessThan(x, Vector512.Zero); + + specialResult = Vector512.ConditionalSelect( + lessThanZeroMask, + Vector512.Create(float.NaN), + specialResult + ); + + // float.IsZero(x) | (x < 0) | float.IsNaN(x) | float.IsPositiveInfinity(x) + Vector512 temp = zeroMask + | lessThanZeroMask + | ~Vector512.Equals(x, x) + | Vector512.Equals(x, Vector512.Create(float.PositiveInfinity)); + + // subnormal + Vector512 subnormalMask = Vector512.AndNot(specialMask.AsSingle(), temp); + + x = Vector512.ConditionalSelect( + subnormalMask, + ((x * 8388608.0f).AsUInt32() - Vector512.Create(23u << 23)).AsSingle(), + x + ); + + specialMask = temp.AsUInt32(); } - return f; + + Vector512 vx = x.AsUInt32() - Vector512.Create(V_OFF); + Vector512 n = Vector512.ConvertToSingle(Vector512.ShiftRightArithmetic(vx.AsInt32(), 23)); + + vx = (vx & Vector512.Create(V_MASK)) + Vector512.Create(V_OFF); + + Vector512 r = vx.AsSingle() - Vector512.Create(1.0f); + + Vector512 r2 = r * r; + Vector512 r4 = r2 * r2; + Vector512 r8 = r4 * r4; + + Vector512 poly = (Vector512.Create(C9) * r + Vector512.Create(C8)) * r8 + + (((Vector512.Create(C7) * r + Vector512.Create(C6)) * r2 + + (Vector512.Create(C5) * r + Vector512.Create(C4))) * r4 + + ((Vector512.Create(C3) * r + Vector512.Create(C2)) * r2 + + (Vector512.Create(C1) * r + Vector512.Create(C0)))); + + return Vector512.ConditionalSelect( + specialMask.AsSingle(), + specialResult, + n + poly + ); } #endif } - private readonly struct DivideOperator : IBinaryOperator + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector128 ElementWiseSelect(Vector128 mask, Vector128 left, Vector128 right) { - public static float Invoke(float x, float y) => x / y; - public static Vector128 Invoke(Vector128 x, Vector128 y) => x / y; - public static Vector256 Invoke(Vector256 x, Vector256 y) => x / y; -#if NET8_0_OR_GREATER - public static Vector512 Invoke(Vector512 x, Vector512 y) => x / y; -#endif + if (Sse41.IsSupported) + return Sse41.BlendVariable(left, right, ~mask); - public static float Invoke(Vector128 x) => throw new NotSupportedException(); - public static float Invoke(Vector256 x) => throw new NotSupportedException(); -#if NET8_0_OR_GREATER - public static float Invoke(Vector512 x) => throw new NotSupportedException(); -#endif + return Vector128.ConditionalSelect(mask, left, right); } - private readonly struct NegateOperator : IUnaryOperator + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector128 ElementWiseSelect(Vector128 mask, Vector128 left, Vector128 right) { - public static float Invoke(float x) => -x; - public static Vector128 Invoke(Vector128 x) => -x; - public static Vector256 Invoke(Vector256 x) => -x; -#if NET8_0_OR_GREATER - public static Vector512 Invoke(Vector512 x) => -x; -#endif - } + if (Sse41.IsSupported) + return Sse41.BlendVariable(left, right, ~mask); - private readonly struct AddMultiplyOperator : ITernaryOperator - { - public static float Invoke(float x, float y, float z) => (x + y) * z; - public static Vector128 Invoke(Vector128 x, Vector128 y, Vector128 z) => (x + y) * z; - public static Vector256 Invoke(Vector256 x, Vector256 y, Vector256 z) => (x + y) * z; -#if NET8_0_OR_GREATER - public static Vector512 Invoke(Vector512 x, Vector512 y, Vector512 z) => (x + y) * z; -#endif + return Vector128.ConditionalSelect(mask, left, right); } - private readonly struct MultiplyAddOperator : ITernaryOperator + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static 
Vector256 ElementWiseSelect(Vector256 mask, Vector256 left, Vector256 right) { - public static float Invoke(float x, float y, float z) => (x * y) + z; - public static Vector128 Invoke(Vector128 x, Vector128 y, Vector128 z) => (x * y) + z; - public static Vector256 Invoke(Vector256 x, Vector256 y, Vector256 z) => (x * y) + z; -#if NET8_0_OR_GREATER - public static Vector512 Invoke(Vector512 x, Vector512 y, Vector512 z) => (x * y) + z; -#endif + if (Avx2.IsSupported) + return Avx2.BlendVariable(left, right, ~mask); + + return Vector256.ConditionalSelect(mask, left, right); } - private readonly struct IdentityOperator : IUnaryOperator + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector256 ElementWiseSelect(Vector256 mask, Vector256 left, Vector256 right) { - public static float Invoke(float x) => x; - public static Vector128 Invoke(Vector128 x) => x; - public static Vector256 Invoke(Vector256 x) => x; -#if NET8_0_OR_GREATER - public static Vector512 Invoke(Vector512 x) => x; -#endif + if (Avx2.IsSupported) + return Avx2.BlendVariable(left, right, ~mask); + + return Vector256.ConditionalSelect(mask, left, right); } - private readonly struct SquaredOperator : IUnaryOperator - { - public static float Invoke(float x) => x * x; - public static Vector128 Invoke(Vector128 x) => x * x; - public static Vector256 Invoke(Vector256 x) => x * x; #if NET8_0_OR_GREATER - public static Vector512 Invoke(Vector512 x) => x * x; -#endif + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector512 ElementWiseSelect(Vector512 mask, Vector512 left, Vector512 right) + { + if (Avx512F.IsSupported) + return Avx512F.BlendVariable(left, right, ~mask); + + return Vector512.ConditionalSelect(mask, left, right); } - private readonly struct AbsoluteOperator : IUnaryOperator + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector512 ElementWiseSelect(Vector512 mask, Vector512 left, Vector512 right) { - public static float Invoke(float x) => MathF.Abs(x); + if (Avx512F.IsSupported) + return Avx512F.BlendVariable(left, right, ~mask); - public static Vector128 Invoke(Vector128 x) - { - Vector128 raw = x.AsUInt32(); - Vector128 mask = Vector128.Create((uint)0x7FFFFFFF); - return (raw & mask).AsSingle(); - } - - public static Vector256 Invoke(Vector256 x) - { - Vector256 raw = x.AsUInt32(); - Vector256 mask = Vector256.Create((uint)0x7FFFFFFF); - return (raw & mask).AsSingle(); - } + return Vector512.ConditionalSelect(mask, left, right); + } +#endif + /// 1f / (1f + MathF.Exp(-x)) + private readonly struct SigmoidOperator : IUnaryOperator + { + public static float Invoke(float x) => 1.0f / (1.0f + MathF.Exp(-x)); + public static Vector128 Invoke(Vector128 x) => Vector128.Create(1f) / (Vector128.Create(1f) + ExpOperator.Invoke(-x)); + public static Vector256 Invoke(Vector256 x) => Vector256.Create(1f) / (Vector256.Create(1f) + ExpOperator.Invoke(-x)); #if NET8_0_OR_GREATER - public static Vector512 Invoke(Vector512 x) - { - Vector512 raw = x.AsUInt32(); - Vector512 mask = Vector512.Create((uint)0x7FFFFFFF); - return (raw & mask).AsSingle(); - } + public static Vector512 Invoke(Vector512 x) => Vector512.Create(1f) / (Vector512.Create(1f) + ExpOperator.Invoke(-x)); #endif } + /// Operator that takes one input value and returns a single value. private interface IUnaryOperator { static abstract float Invoke(float x); @@ -1248,20 +11570,30 @@ private interface IUnaryOperator #endif } + /// Operator that takes two input values and returns a single value. 
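+        /// (For example, an addition operator would return x + y from the scalar overload and
+        /// from each vector-width overload.)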
private interface IBinaryOperator { static abstract float Invoke(float x, float y); - static abstract Vector128 Invoke(Vector128 x, Vector128 y); - static abstract float Invoke(Vector128 x); static abstract Vector256 Invoke(Vector256 x, Vector256 y); - static abstract float Invoke(Vector256 x); #if NET8_0_OR_GREATER static abstract Vector512 Invoke(Vector512 x, Vector512 y); +#endif + } + + /// that specializes horizontal aggregation of all elements in a vector. + private interface IAggregationOperator : IBinaryOperator + { + static abstract float Invoke(Vector128 x); + static abstract float Invoke(Vector256 x); +#if NET8_0_OR_GREATER static abstract float Invoke(Vector512 x); #endif + + static virtual float IdentityValue => throw new NotSupportedException(); } + /// Operator that takes three input values and returns a single value. private interface ITernaryOperator { static abstract float Invoke(float x, float y, float z); diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs index ed8b3aea0d560..c0039be0a08e2 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Diagnostics; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; @@ -8,14 +9,8 @@ namespace System.Numerics.Tensors { public static partial class TensorPrimitives { - private static unsafe bool IsNegative(float f) => *(int*)&f < 0; - - private static float MaxMagnitude(float x, float y) => MathF.Abs(x) >= MathF.Abs(y) ? x : y; - - private static float MinMagnitude(float x, float y) => MathF.Abs(x) < MathF.Abs(y) ? x : y; - - private static float Log2(float x) => MathF.Log(x, 2); - + /// Computes the cosine similarity between the two specified non-empty, equal-length tensors of single-precision floating-point numbers. + /// Assumes arguments have already been validated to be non-empty and equal length. private static float CosineSimilarityCore(ReadOnlySpan x, ReadOnlySpan y) { // Compute the same as: @@ -26,9 +21,9 @@ private static float CosineSimilarityCore(ReadOnlySpan x, ReadOnlySpan= Vector.Count) + if (Vector.IsHardwareAccelerated && + Vector.Count <= 16 && // currently never greater than 8, but 16 would occur if/when AVX512 is supported, and logic in remainder handling assumes that maximum + x.Length >= Vector.Count) { ref float xRef = ref MemoryMarshal.GetReference(x); ref float yRef = ref MemoryMarshal.GetReference(y); @@ -39,6 +34,7 @@ private static float CosineSimilarityCore(ReadOnlySpan x, ReadOnlySpan.Count; + int i = 0; do { Vector xVec = AsVector(ref xRef, i); @@ -52,6 +48,21 @@ private static float CosineSimilarityCore(ReadOnlySpan x, ReadOnlySpan xVec = AsVector(ref xRef, x.Length - Vector.Count); + Vector yVec = AsVector(ref yRef, x.Length - Vector.Count); + + Vector remainderMask = CreateRemainderMaskSingleVector(x.Length - i); + xVec &= remainderMask; + yVec &= remainderMask; + + dotProductVector += xVec * yVec; + xSumOfSquaresVector += xVec * xVec; + ySumOfSquaresVector += yVec * yVec; + } + // Sum the vector lanes into the scalar result. 
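            // (Each of dotProductVector, xSumOfSquaresVector, and ySumOfSquaresVector is
            // reduced lane by lane into its scalar counterpart by the loop below.)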
for (int e = 0; e < Vector.Count; e++) { @@ -60,539 +71,3464 @@ private static float CosineSimilarityCore(ReadOnlySpan x, ReadOnlySpan( - float identityValue, ReadOnlySpan x, TLoad load = default, TAggregate aggregate = default) - where TLoad : struct, IUnaryOperator - where TAggregate : struct, IBinaryOperator - { - // Initialize the result to the identity value - float result = identityValue; - int i = 0; + /// Performs an aggregation over all elements in to produce a single-precision floating-point value. + /// Specifies the transform operation that should be applied to each element loaded from . + /// + /// Specifies the aggregation binary operation that should be applied to multiple values to aggregate them into a single value. + /// The aggregation is applied after the transform is applied to each element. + /// + private static unsafe float Aggregate( + ReadOnlySpan x, TTransformOperator transformOp = default, TAggregationOperator aggregationOp = default) + where TTransformOperator : struct, IUnaryOperator + where TAggregationOperator : struct, IAggregationOperator + { + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + + ref float xRef = ref MemoryMarshal.GetReference(x); + + nuint remainder = (uint)(x.Length); + + if (Vector.IsHardwareAccelerated && transformOp.CanVectorize) + { + float result; + + if (remainder >= (uint)(Vector.Count)) + { + result = Vectorized(ref xRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + result = VectorizedSmall(ref xRef, remainder); + } + + return result; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + return SoftwareFallback(ref xRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float SoftwareFallback(ref float xRef, nuint length, TTransformOperator transformOp = default, TAggregationOperator aggregationOp = default) + { + float result = aggregationOp.IdentityValue; + + for (nuint i = 0; i < length; i++) + { + result = aggregationOp.Invoke(result, transformOp.Invoke(Unsafe.Add(ref xRef, (nint)(i)))); + } + + return result; + } + + static float Vectorized(ref float xRef, nuint remainder, TTransformOperator transformOp = default, TAggregationOperator aggregationOp = default) + { + Vector vresult = new Vector(aggregationOp.IdentityValue); + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector beg = transformOp.Invoke(AsVector(ref xRef)); + Vector end = transformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count))); + + nuint misalignment = 0; + + if (remainder > (uint)(Vector.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + { + float* xPtr = px; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. 
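+                        // (canAlign below checks element, i.e. 4-byte, alignment: only a pointer
+                        // sitting on a whole-float boundary can ever reach a Vector<float>-sized
+                        // boundary by stepping over whole elements.)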
+ + bool canAlign = ((nuint)(xPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + misalignment = ((uint)(sizeof(Vector)) - ((nuint)(xPtr) % (uint)(sizeof(Vector)))) / sizeof(float); + + xPtr += misalignment; + + Debug.Assert(((nuint)(xPtr) % (uint)(sizeof(Vector))) == 0); + + remainder -= misalignment; + } + + Vector vector1; + Vector vector2; + Vector vector3; + Vector vector4; + + // We only need to load, so there isn't a lot of benefit to doing non-temporal operations + + while (remainder >= (uint)(Vector.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = transformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 0))); + vector2 = transformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 1))); + vector3 = transformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 2))); + vector4 = transformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 3))); + + vresult = aggregationOp.Invoke(vresult, vector1); + vresult = aggregationOp.Invoke(vresult, vector2); + vresult = aggregationOp.Invoke(vresult, vector3); + vresult = aggregationOp.Invoke(vresult, vector4); + + // We load, process, and store the next four vectors + + vector1 = transformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 4))); + vector2 = transformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 5))); + vector3 = transformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 6))); + vector4 = transformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 7))); + + vresult = aggregationOp.Invoke(vresult, vector1); + vresult = aggregationOp.Invoke(vresult, vector2); + vresult = aggregationOp.Invoke(vresult, vector3); + vresult = aggregationOp.Invoke(vresult, vector4); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector.Count * 8); + + remainder -= (uint)(Vector.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + } + } + + // Store the first block. Handling this separately simplifies the latter code as we know + // they come after and so we can relegate it to full blocks or the trailing elements + + beg = Vector.ConditionalSelect(CreateAlignmentMaskSingleVector((int)(misalignment)), beg, new Vector(aggregationOp.IdentityValue)); + vresult = aggregationOp.Invoke(vresult, beg); + + // Process the remaining [0, Count * 7] elements via a jump table + // + // We end up handling any trailing elements in case 0 and in the + // worst case end up just doing the identity operation here if there + // were no trailing elements. + + nuint blocks = remainder / (nuint)(Vector.Count); + nuint trailing = remainder - (blocks * (nuint)(Vector.Count)); + blocks -= (misalignment == 0) ? 
1u : 0u; + remainder -= trailing; + + switch (blocks) + { + case 7: + { + Vector vector = transformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 7))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 6; + } + + case 6: + { + Vector vector = transformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 6))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 5; + } + + case 5: + { + Vector vector = transformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 5))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 4; + } + + case 4: + { + Vector vector = transformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 4))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 3; + } + + case 3: + { + Vector vector = transformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 3))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 2; + } + + case 2: + { + Vector vector = transformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 2))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 1; + } + + case 1: + { + Vector vector = transformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 1))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 0; + } + + case 0: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end = Vector.ConditionalSelect(CreateRemainderMaskSingleVector((int)(trailing)), end, new Vector(aggregationOp.IdentityValue)); + vresult = aggregationOp.Invoke(vresult, end); + break; + } + } + + float result = aggregationOp.IdentityValue; + + for (int i = 0; i < Vector.Count; i++) + { + result = aggregationOp.Invoke(result, vresult[i]); + } + + return result; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float VectorizedSmall(ref float xRef, nuint remainder, TTransformOperator transformOp = default, TAggregationOperator aggregationOp = default) + { + float result = aggregationOp.IdentityValue; + + switch (remainder) + { + case 7: + { + result = aggregationOp.Invoke(result, transformOp.Invoke(Unsafe.Add(ref xRef, 6))); + goto case 6; + } + + case 6: + { + result = aggregationOp.Invoke(result, transformOp.Invoke(Unsafe.Add(ref xRef, 5))); + goto case 5; + } + + case 5: + { + result = aggregationOp.Invoke(result, transformOp.Invoke(Unsafe.Add(ref xRef, 4))); + goto case 4; + } + + case 4: + { + result = aggregationOp.Invoke(result, transformOp.Invoke(Unsafe.Add(ref xRef, 3))); + goto case 3; + } + + case 3: + { + result = aggregationOp.Invoke(result, transformOp.Invoke(Unsafe.Add(ref xRef, 2))); + goto case 2; + } + + case 2: + { + result = aggregationOp.Invoke(result, transformOp.Invoke(Unsafe.Add(ref xRef, 1))); + goto case 1; + } + + case 1: + { + result = aggregationOp.Invoke(result, transformOp.Invoke(xRef)); + goto case 0; + } + + case 0: + { + break; + } + } + + return result; + } + } + + /// Performs an aggregation over all pair-wise elements in and to produce a single-precision floating-point value. + /// Specifies the binary operation that should be applied to the pair-wise elements loaded from and . + /// + /// Specifies the aggregation binary operation that should be applied to multiple values to aggregate them into a single value. + /// The aggregation is applied to the results of the binary operations on the pair-wise values. 
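+        /// For example, a dot product corresponds to a multiply binary operator whose pairwise
+        /// products are then folded together by an add aggregation operator.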
+ /// + private static unsafe float Aggregate( + ReadOnlySpan x, ReadOnlySpan y, TBinaryOperator binaryOp = default, TAggregationOperator aggregationOp = default) + where TBinaryOperator : struct, IBinaryOperator + where TAggregationOperator : struct, IAggregationOperator + { + if (x.Length != y.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + + nuint remainder = (uint)(x.Length); + + if (Vector.IsHardwareAccelerated) + { + float result; + + if (remainder >= (uint)(Vector.Count)) + { + result = Vectorized(ref xRef, ref yRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + result = VectorizedSmall(ref xRef, ref yRef, remainder); + } + + return result; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + return SoftwareFallback(ref xRef, ref yRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float SoftwareFallback(ref float xRef, ref float yRef, nuint length, TBinaryOperator binaryOp = default, TAggregationOperator aggregationOp = default) + { + float result = aggregationOp.IdentityValue; + + for (nuint i = 0; i < length; i++) + { + result = aggregationOp.Invoke(result, binaryOp.Invoke(Unsafe.Add(ref xRef, (nint)(i)), + Unsafe.Add(ref yRef, (nint)(i)))); + } + + return result; + } + + static float Vectorized(ref float xRef, ref float yRef, nuint remainder, TBinaryOperator binaryOp = default, TAggregationOperator aggregationOp = default) + { + Vector vresult = new Vector(aggregationOp.IdentityValue); + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector beg = binaryOp.Invoke(AsVector(ref xRef), + AsVector(ref yRef)); + Vector end = binaryOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count)), + AsVector(ref yRef, remainder - (uint)(Vector.Count))); + + nuint misalignment = 0; + + if (remainder > (uint)(Vector.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + { + float* xPtr = px; + float* yPtr = py; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(xPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. 
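+                        // (e.g., with a 32-byte Vector<float> and xPtr % 32 == 12, the math below
+                        // gives (32 - 12) / 4 = 5 elements; advancing both pointers by 5 floats
+                        // (20 bytes) lands xPtr on a 32-byte boundary.)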
+ + misalignment = ((uint)(sizeof(Vector)) - ((nuint)(xPtr) % (uint)(sizeof(Vector)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + + Debug.Assert(((nuint)(xPtr) % (uint)(sizeof(Vector))) == 0); + + remainder -= misalignment; + } + + Vector vector1; + Vector vector2; + Vector vector3; + Vector vector4; + + // We only need to load, so there isn't a lot of benefit to doing non-temporal operations + + while (remainder >= (uint)(Vector.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = binaryOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 0)), + *(Vector*)(yPtr + (uint)(Vector.Count * 0))); + vector2 = binaryOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 1)), + *(Vector*)(yPtr + (uint)(Vector.Count * 1))); + vector3 = binaryOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 2)), + *(Vector*)(yPtr + (uint)(Vector.Count * 2))); + vector4 = binaryOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 3)), + *(Vector*)(yPtr + (uint)(Vector.Count * 3))); + + vresult = aggregationOp.Invoke(vresult, vector1); + vresult = aggregationOp.Invoke(vresult, vector2); + vresult = aggregationOp.Invoke(vresult, vector3); + vresult = aggregationOp.Invoke(vresult, vector4); + + // We load, process, and store the next four vectors + + vector1 = binaryOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 4)), + *(Vector*)(yPtr + (uint)(Vector.Count * 4))); + vector2 = binaryOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 5)), + *(Vector*)(yPtr + (uint)(Vector.Count * 5))); + vector3 = binaryOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 6)), + *(Vector*)(yPtr + (uint)(Vector.Count * 6))); + vector4 = binaryOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 7)), + *(Vector*)(yPtr + (uint)(Vector.Count * 7))); + + vresult = aggregationOp.Invoke(vresult, vector1); + vresult = aggregationOp.Invoke(vresult, vector2); + vresult = aggregationOp.Invoke(vresult, vector3); + vresult = aggregationOp.Invoke(vresult, vector4); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector.Count * 8); + yPtr += (uint)(Vector.Count * 8); + + remainder -= (uint)(Vector.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + } + } + + // Store the first block. Handling this separately simplifies the latter code as we know + // they come after and so we can relegate it to full blocks or the trailing elements + + beg = Vector.ConditionalSelect(CreateAlignmentMaskSingleVector((int)(misalignment)), beg, new Vector(aggregationOp.IdentityValue)); + vresult = aggregationOp.Invoke(vresult, beg); + + // Process the remaining [0, Count * 7] elements via a jump table + // + // We end up handling any trailing elements in case 0 and in the + // worst case end up just doing the identity operation here if there + // were no trailing elements. + + nuint blocks = remainder / (nuint)(Vector.Count); + nuint trailing = remainder - (blocks * (nuint)(Vector.Count)); + blocks -= (misalignment == 0) ? 
1u : 0u; + remainder -= trailing; + + switch (blocks) + { + case 7: + { + Vector vector = binaryOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 7)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 7))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 6; + } + + case 6: + { + Vector vector = binaryOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 6)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 6))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 5; + } + + case 5: + { + Vector vector = binaryOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 5)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 5))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 4; + } + + case 4: + { + Vector vector = binaryOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 4)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 4))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 3; + } + + case 3: + { + Vector vector = binaryOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 3)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 3))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 2; + } + + case 2: + { + Vector vector = binaryOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 2)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 2))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 1; + } + + case 1: + { + Vector vector = binaryOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 1)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 1))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 0; + } + + case 0: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end = Vector.ConditionalSelect(CreateRemainderMaskSingleVector((int)(trailing)), end, new Vector(aggregationOp.IdentityValue)); + vresult = aggregationOp.Invoke(vresult, end); + break; + } + } + + float result = aggregationOp.IdentityValue; + + for (int i = 0; i < Vector.Count; i++) + { + result = aggregationOp.Invoke(result, vresult[i]); + } + + return result; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float VectorizedSmall(ref float xRef, ref float yRef, nuint remainder, TBinaryOperator binaryOp = default, TAggregationOperator aggregationOp = default) + { + float result = aggregationOp.IdentityValue; + + switch (remainder) + { + case 7: + { + result = aggregationOp.Invoke(result, binaryOp.Invoke(Unsafe.Add(ref xRef, 6), + Unsafe.Add(ref yRef, 6))); + goto case 6; + } + + case 6: + { + result = aggregationOp.Invoke(result, binaryOp.Invoke(Unsafe.Add(ref xRef, 5), + Unsafe.Add(ref yRef, 5))); + goto case 5; + } + + case 5: + { + result = aggregationOp.Invoke(result, binaryOp.Invoke(Unsafe.Add(ref xRef, 4), + Unsafe.Add(ref yRef, 4))); + goto case 4; + } + + case 4: + { + result = aggregationOp.Invoke(result, binaryOp.Invoke(Unsafe.Add(ref xRef, 3), + Unsafe.Add(ref yRef, 3))); + goto case 3; + } + + case 3: + { + result = aggregationOp.Invoke(result, binaryOp.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2))); + goto case 2; + } + + case 2: + { + result = aggregationOp.Invoke(result, binaryOp.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1))); + goto case 1; + } + + case 1: + { + result = aggregationOp.Invoke(result, binaryOp.Invoke(xRef, yRef)); + goto case 0; + } + + case 0: + { + break; + } + } + + return result; + } + } + + /// + /// This is 
the same as <see cref="Aggregate{TTransformOperator, TAggregationOperator}(ReadOnlySpan{float})"/>
+ /// with an identity transform, except it early exits on NaN.
+ /// </summary>
+ private static float MinMaxCore<TMinMaxOperator>(ReadOnlySpan<float> x, TMinMaxOperator op = default)
+ where TMinMaxOperator : struct, IBinaryOperator
+ {
+ if (x.IsEmpty)
+ {
+ ThrowHelper.ThrowArgument_SpansMustBeNonEmpty();
+ }
+
+ // This matches the IEEE 754:2019 `maximum`/`minimum` functions.
+ // It propagates NaN inputs back to the caller and
+ // otherwise returns the greater of the inputs.
+ // It treats +0 as greater than -0 as per the specification.
+
+ float result = x[0];
+ int i = 0;
+
+ if (Vector.IsHardwareAccelerated && x.Length >= Vector<float>.Count)
+ {
+ ref float xRef = ref MemoryMarshal.GetReference(x);
+
+ // Load the first vector as the initial set of results, and bail immediately
+ // to scalar handling if it contains any NaNs (which don't compare equally to themselves).
+ Vector<float> resultVector = AsVector(ref xRef, 0), current;
+ if (Vector.EqualsAll(resultVector, resultVector))
+ {
+ int oneVectorFromEnd = x.Length - Vector<float>.Count;
+ i = Vector<float>.Count;
+
+ // Aggregate additional vectors into the result as long as there's at least one full vector left to process.
+ while (i <= oneVectorFromEnd)
+ {
+ // Load the next vector, and early exit on NaN.
+ current = AsVector(ref xRef, i);
+ if (!Vector.EqualsAll(current, current))
+ {
+ goto Scalar;
+ }
+
+ resultVector = op.Invoke(resultVector, current);
+ i += Vector<float>.Count;
+ }
+
+ // If any elements remain, handle them in one final vector.
+ if (i != x.Length)
+ {
+ current = AsVector(ref xRef, x.Length - Vector<float>.Count);
+ if (!Vector.EqualsAll(current, current))
+ {
+ goto Scalar;
+ }
+
+ resultVector = op.Invoke(resultVector, current);
+ }
+
+ // Aggregate the lanes in the vector to create the final scalar result.
+ for (int f = 0; f < Vector<float>.Count; f++)
+ {
+ result = op.Invoke(result, resultVector[f]);
+ }
+
+ return result;
+ }
+ }
+
+ // Scalar path used when either vectorization is not supported, the input is too small to vectorize,
+ // or a NaN is encountered.
+ Scalar:
+ for (; (uint)i < (uint)x.Length; i++)
+ {
+ float current = x[i];
+
+ if (float.IsNaN(current))
+ {
+ return current;
+ }
+
+ result = op.Invoke(result, current);
+ }
+
+ return result;
+ }
+
+ private static readonly int[] s_0through7 = [0, 1, 2, 3, 4, 5, 6, 7];
+
+ private static int IndexOfMinMaxCore<TIndexOfMinMaxOperator>(ReadOnlySpan<float> x, TIndexOfMinMaxOperator op = default)
+ where TIndexOfMinMaxOperator : struct, IIndexOfOperator
+ {
+ // This matches the IEEE 754:2019 `maximum`/`minimum` functions.
+ // It propagates NaN inputs back to the caller and
+ // otherwise returns the index of the greater of the inputs.
+ // It treats +0 as greater than -0 as per the specification.
+
+ int result;
+ int i = 0;
+
+ if (Vector.IsHardwareAccelerated && Vector<float>.Count <= 8 && x.Length >= Vector<float>.Count)
+ {
+ ref float xRef = ref MemoryMarshal.GetReference(x);
+
+ Vector<int> resultIndex = new Vector<int>(s_0through7);
+ Vector<int> curIndex = resultIndex;
+ Vector<int> increment = new Vector<int>(Vector<float>.Count);
+
+ // Load the first vector as the initial set of results, and bail immediately
+ // to scalar handling if it contains any NaNs (which don't compare equally to themselves).
+ Vector<float> resultVector = AsVector(ref xRef, 0), current;
+ if (Vector.EqualsAll(resultVector, resultVector))
+ {
+ int oneVectorFromEnd = x.Length - Vector<float>.Count;
+ i = Vector<float>.Count;
+
+ // Aggregate additional vectors into the result as long as there's at least one full vector left to process.
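+ // A sketch of the self-equality trick relied on below (illustration only, not part of
+ // this change): IEEE 754 defines NaN != NaN, so Vector.EqualsAll(v, v) is false exactly
+ // when some lane of v is NaN, letting one comparison screen an entire vector:
+ //
+ //     static bool ContainsNaN(Vector<float> v) => !Vector.EqualsAll(v, v);
+ //
+ //     ContainsNaN(new Vector<float>(1f));        // false
+ //     ContainsNaN(new Vector<float>(float.NaN)); // true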
+ while (i <= oneVectorFromEnd) + { + // Load the next vector, and early exit on NaN. + current = AsVector(ref xRef, i); + curIndex = Vector.Add(curIndex, increment); + + if (!Vector.EqualsAll(current, current)) + { + goto Scalar; + } + + op.Invoke(ref resultVector, current, ref resultIndex, curIndex); + i += Vector.Count; + } + + // If any elements remain, handle them in one final vector. + if (i != x.Length) + { + curIndex = Vector.Add(curIndex, new Vector(x.Length - i)); + + current = AsVector(ref xRef, x.Length - Vector.Count); + if (!Vector.EqualsAll(current, current)) + { + goto Scalar; + } + + op.Invoke(ref resultVector, current, ref resultIndex, curIndex); + } + + result = op.Invoke(resultVector, resultIndex); + + return result; + } + } + + // Scalar path used when either vectorization is not supported, the input is too small to vectorize, + // or a NaN is encountered. + Scalar: + float curResult = x[i]; + int curIn = i; + if (float.IsNaN(curResult)) + { + return curIn; + } + + for (; i < x.Length; i++) + { + float current = x[i]; + if (float.IsNaN(current)) + { + return i; + } + + curIn = op.Invoke(ref curResult, current, curIn, i); + } + + return curIn; + } + + /// Performs an element-wise operation on and writes the results to . + /// Specifies the operation to perform on each element loaded from . + private static unsafe void InvokeSpanIntoSpan( + ReadOnlySpan x, Span destination, TUnaryOperator op = default) + where TUnaryOperator : struct, IUnaryOperator + { + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ValidateInputOutputSpanNonOverlapping(x, destination); + + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float dRef = ref MemoryMarshal.GetReference(destination); + + nuint remainder = (uint)(x.Length); + + if (Vector.IsHardwareAccelerated && op.CanVectorize) + { + if (remainder >= (uint)(Vector.Count)) + { + Vectorized(ref xRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + VectorizedSmall(ref xRef, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, ref float dRef, nuint length, TUnaryOperator op = default) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, (nint)(i)) = op.Invoke(Unsafe.Add(ref xRef, (nint)(i))); + } + } + + static void Vectorized(ref float xRef, ref float dRef, nuint remainder, TUnaryOperator op = default) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector beg = op.Invoke(AsVector(ref xRef)); + Vector end = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count))); + + if (remainder > (uint)(Vector.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. 
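+ // The generic operator parameters used throughout these helpers (see op.CanVectorize
+ // above) are empty structs, so the JIT specializes and inlines each instantiation. A
+ // hypothetical operator pair in the same shape, as a sketch only, with the interfaces
+ // abbreviated to the members this file actually calls:
+ //
+ //     readonly struct IdentityOperator : IUnaryOperator
+ //     {
+ //         public bool CanVectorize => true;
+ //         public float Invoke(float x) => x;
+ //         public Vector<float> Invoke(Vector<float> x) => x;
+ //     }
+ //
+ //     readonly struct AddOperator : IBinaryOperator
+ //     {
+ //         public float Invoke(float x, float y) => x + y;
+ //         public Vector<float> Invoke(Vector<float> x, Vector<float> y) => x + y;
+ //     }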
+ + fixed (float* px = &xRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector)) - ((nuint)(dPtr) % (uint)(sizeof(Vector)))) / sizeof(float); + + xPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector))) == 0); + + remainder -= misalignment; + } + + Vector vector1; + Vector vector2; + Vector vector3; + Vector vector4; + + while (remainder >= (uint)(Vector.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 0))); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 1))); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 2))); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 3))); + + *(Vector*)(dPtr + (uint)(Vector.Count * 0)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 1)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 2)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 3)) = vector4; + + // We load, process, and store the next four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 4))); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 5))); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 6))); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 7))); + + *(Vector*)(dPtr + (uint)(Vector.Count * 4)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 5)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 6)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 7)) = vector4; + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector.Count * 8); + dPtr += (uint)(Vector.Count * 8); + + remainder -= (uint)(Vector.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
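+ // The rounding statement below bumps `remainder` up to the next multiple of
+ // Vector<float>.Count; this works because Count is a power of two. As a standalone
+ // sketch of the same expression:
+ //
+ //     static nuint RoundUpToVectorCount(nuint n) =>
+ //         (n + (uint)(Vector<float>.Count - 1)) & (nuint)(-Vector<float>.Count);
+ //
+ // For example, with Count == 8: 19 -> (19 + 7) & ~7 = 24, while an exact multiple is
+ // unchanged: 16 -> (16 + 7) & ~7 = 16. Dividing the rounded value by Count then selects
+ // the entry case of the jump table.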
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector.Count - 1)) & (nuint)(-Vector.Count); + + switch (remainder / (uint)(Vector.Count)) + { + case 8: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 8))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 8)) = vector; + goto case 7; + } + + case 7: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 7))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 7)) = vector; + goto case 6; + } + + case 6: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 6))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 6)) = vector; + goto case 5; + } + + case 5: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 5))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 5)) = vector; + goto case 4; + } + + case 4: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 4))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 4)) = vector; + goto case 3; + } + + case 3: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 3))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 3)) = vector; + goto case 2; + } + + case 2: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 2))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 2)) = vector; + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + AsVector(ref dRef, endIndex - (uint)Vector.Count) = end; + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + AsVector(ref dRefBeg) = beg; + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void VectorizedSmall(ref float xRef, ref float dRef, nuint remainder, TUnaryOperator op = default) + { + switch (remainder) + { + case 7: + { + Unsafe.Add(ref dRef, 6) = op.Invoke(Unsafe.Add(ref xRef, 6)); + goto case 6; + } + + case 6: + { + Unsafe.Add(ref dRef, 5) = op.Invoke(Unsafe.Add(ref xRef, 5)); + goto case 5; + } + + case 5: + { + Unsafe.Add(ref dRef, 4) = op.Invoke(Unsafe.Add(ref xRef, 4)); + goto case 4; + } + + case 4: + { + Unsafe.Add(ref dRef, 3) = op.Invoke(Unsafe.Add(ref xRef, 3)); + goto case 3; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = op.Invoke(Unsafe.Add(ref xRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = op.Invoke(Unsafe.Add(ref xRef, 1)); + goto case 1; + } + + case 1: + { + dRef = op.Invoke(xRef); + goto case 0; + } + + case 0: + { + break; + } + } + } + } + + /// + /// Performs an element-wise operation on and , + /// and writes the results to . + /// + /// + /// Specifies the operation to perform on the pair-wise elements loaded from and . 
+ /// + private static unsafe void InvokeSpanSpanIntoSpan( + ReadOnlySpan x, ReadOnlySpan y, Span destination, TBinaryOperator op = default) + where TBinaryOperator : struct, IBinaryOperator + { + if (x.Length != y.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ValidateInputOutputSpanNonOverlapping(x, destination); + ValidateInputOutputSpanNonOverlapping(y, destination); + + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + ref float dRef = ref MemoryMarshal.GetReference(destination); + + nuint remainder = (uint)(x.Length); + + if (Vector.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector.Count)) + { + Vectorized(ref xRef, ref yRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + VectorizedSmall(ref xRef, ref yRef, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, ref yRef, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, ref float yRef, ref float dRef, nuint length, TBinaryOperator op = default) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, (nint)(i)) = op.Invoke(Unsafe.Add(ref xRef, (nint)(i)), + Unsafe.Add(ref yRef, (nint)(i))); + } + } + + static void Vectorized(ref float xRef, ref float yRef, ref float dRef, nuint remainder, TBinaryOperator op = default) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector beg = op.Invoke(AsVector(ref xRef), + AsVector(ref yRef)); + Vector end = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count)), + AsVector(ref yRef, remainder - (uint)(Vector.Count))); + + if (remainder > (uint)(Vector.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. 
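+ // Worked example for the misalignment computation that follows, assuming a 32-byte
+ // Vector<float> (Count == 8): if dPtr % 32 == 20, then misalignment ==
+ // (32 - 20) / sizeof(float) == 3 elements, and advancing every pointer by 3 lands dPtr
+ // on a 32-byte boundary. If dPtr is already aligned, the formula yields 32 / 4 == 8,
+ // so one whole vector is skipped; the preloaded `beg` vector covers those elements.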
+ + nuint misalignment = ((uint)(sizeof(Vector)) - ((nuint)(dPtr) % (uint)(sizeof(Vector)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector))) == 0); + + remainder -= misalignment; + } + + Vector vector1; + Vector vector2; + Vector vector3; + Vector vector4; + + while (remainder >= (uint)(Vector.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 0)), + *(Vector*)(yPtr + (uint)(Vector.Count * 0))); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 1)), + *(Vector*)(yPtr + (uint)(Vector.Count * 1))); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 2)), + *(Vector*)(yPtr + (uint)(Vector.Count * 2))); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 3)), + *(Vector*)(yPtr + (uint)(Vector.Count * 3))); + + *(Vector*)(dPtr + (uint)(Vector.Count * 0)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 1)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 2)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 3)) = vector4; + + // We load, process, and store the next four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 4)), + *(Vector*)(yPtr + (uint)(Vector.Count * 4))); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 5)), + *(Vector*)(yPtr + (uint)(Vector.Count * 5))); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 6)), + *(Vector*)(yPtr + (uint)(Vector.Count * 6))); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 7)), + *(Vector*)(yPtr + (uint)(Vector.Count * 7))); + + *(Vector*)(dPtr + (uint)(Vector.Count * 4)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 5)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 6)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 7)) = vector4; + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector.Count * 8); + yPtr += (uint)(Vector.Count * 8); + dPtr += (uint)(Vector.Count * 8); + + remainder -= (uint)(Vector.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
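+ // Note on the jump table below: `beg` and `end` were computed up front from the
+ // original (unaligned) bounds and are stored last, in cases 1 and 0. They may overlap
+ // elements already written by the aligned loop or by higher cases, but every element
+ // receives the same value either way, so the redundant stores are harmless. This
+ // overlap trick is what lets the routine avoid a scalar cleanup loop entirely.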
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector.Count - 1)) & (nuint)(-Vector.Count); + + switch (remainder / (uint)(Vector.Count)) + { + case 8: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 8)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 8))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 8)) = vector; + goto case 7; + } + + case 7: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 7)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 7))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 7)) = vector; + goto case 6; + } + + case 6: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 6)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 6))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 6)) = vector; + goto case 5; + } + + case 5: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 5)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 5))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 5)) = vector; + goto case 4; + } + + case 4: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 4)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 4))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 4)) = vector; + goto case 3; + } + + case 3: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 3)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 3))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 3)) = vector; + goto case 2; + } + + case 2: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 2)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 2))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 2)) = vector; + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + AsVector(ref dRef, endIndex - (uint)Vector.Count) = end; + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + AsVector(ref dRefBeg) = beg; + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void VectorizedSmall(ref float xRef, ref float yRef, ref float dRef, nuint remainder, TBinaryOperator op = default) + { + switch (remainder) + { + case 7: + { + Unsafe.Add(ref dRef, 6) = op.Invoke(Unsafe.Add(ref xRef, 6), + Unsafe.Add(ref yRef, 6)); + goto case 6; + } + + case 6: + { + Unsafe.Add(ref dRef, 5) = op.Invoke(Unsafe.Add(ref xRef, 5), + Unsafe.Add(ref yRef, 5)); + goto case 5; + } + + case 5: + { + Unsafe.Add(ref dRef, 4) = op.Invoke(Unsafe.Add(ref xRef, 4), + Unsafe.Add(ref yRef, 4)); + goto case 4; + } + + case 4: + { + Unsafe.Add(ref dRef, 3) = op.Invoke(Unsafe.Add(ref xRef, 3), + Unsafe.Add(ref yRef, 3)); + goto case 3; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = op.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = op.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1)); + goto case 1; + } + + case 1: + { + dRef = op.Invoke(xRef, yRef); + goto case 0; + } + + case 0: + { + break; + } + } + } + } + + /// + /// Performs an element-wise operation on and , + /// and writes the results to . + /// + /// + /// Specifies the operation to perform on each element loaded from with . 
+ /// + private static void InvokeSpanScalarIntoSpan( + ReadOnlySpan x, float y, Span destination, TBinaryOperator op = default) + where TBinaryOperator : struct, IBinaryOperator => + InvokeSpanScalarIntoSpan(x, y, destination, default, op); + + /// + /// Performs an element-wise operation on and , + /// and writes the results to . + /// + /// + /// Specifies the operation to perform on each element loaded from . + /// It is not used with . + /// + /// + /// Specifies the operation to perform on the transformed value from with . + /// + private static unsafe void InvokeSpanScalarIntoSpan( + ReadOnlySpan x, float y, Span destination, TTransformOperator xTransformOp = default, TBinaryOperator binaryOp = default) + where TTransformOperator : struct, IUnaryOperator + where TBinaryOperator : struct, IBinaryOperator + { + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ValidateInputOutputSpanNonOverlapping(x, destination); + + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float dRef = ref MemoryMarshal.GetReference(destination); + + nuint remainder = (uint)(x.Length); + + if (Vector.IsHardwareAccelerated && xTransformOp.CanVectorize) + { + if (remainder >= (uint)(Vector.Count)) + { + Vectorized(ref xRef, y, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + VectorizedSmall(ref xRef, y, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, y, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, float y, ref float dRef, nuint length, TTransformOperator xTransformOp = default, TBinaryOperator binaryOp = default) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, (nint)(i)) = binaryOp.Invoke(xTransformOp.Invoke(Unsafe.Add(ref xRef, (nint)(i))), + y); + } + } + + static void Vectorized(ref float xRef, float y, ref float dRef, nuint remainder, TTransformOperator xTransformOp = default, TBinaryOperator binaryOp = default) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector yVec = new Vector(y); + + Vector beg = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef)), + yVec); + Vector end = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count))), + yVec); + + if (remainder > (uint)(Vector.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. 
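+ // The `canAlign` check that follows matters: a ref reinterpreted from arbitrary memory
+ // might not even be float-aligned, in which case no whole number of element-sized steps
+ // can ever reach a Vector<float> boundary. The routine then simply runs the main loop
+ // with unaligned stores instead of computing a meaningless misalignment.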
+ + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector)) - ((nuint)(dPtr) % (uint)(sizeof(Vector)))) / sizeof(float); + + xPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector))) == 0); + + remainder -= misalignment; + } + + Vector vector1; + Vector vector2; + Vector vector3; + Vector vector4; + + while (remainder >= (uint)(Vector.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = binaryOp.Invoke(xTransformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 0))), + yVec); + vector2 = binaryOp.Invoke(xTransformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 1))), + yVec); + vector3 = binaryOp.Invoke(xTransformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 2))), + yVec); + vector4 = binaryOp.Invoke(xTransformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 3))), + yVec); + + *(Vector*)(dPtr + (uint)(Vector.Count * 0)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 1)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 2)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 3)) = vector4; + + // We load, process, and store the next four vectors + + vector1 = binaryOp.Invoke(xTransformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 4))), + yVec); + vector2 = binaryOp.Invoke(xTransformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 5))), + yVec); + vector3 = binaryOp.Invoke(xTransformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 6))), + yVec); + vector4 = binaryOp.Invoke(xTransformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 7))), + yVec); + + *(Vector*)(dPtr + (uint)(Vector.Count * 4)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 5)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 6)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 7)) = vector4; + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector.Count * 8); + dPtr += (uint)(Vector.Count * 8); + + remainder -= (uint)(Vector.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
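+ // The main loop above is unrolled to 8 vectors (64 floats when Count == 8) per
+ // iteration: the loads, operator invocations, and aligned stores are independent, so
+ // they can overlap in the pipeline, and the loop overhead (the pointer bumps, the
+ // subtraction, and the compare-and-branch) is paid once per 8 vectors rather than once
+ // per vector.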
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector.Count - 1)) & (nuint)(-Vector.Count); + + switch (remainder / (uint)(Vector.Count)) + { + case 8: + { + Vector vector = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 8))), + yVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 8)) = vector; + goto case 7; + } + + case 7: + { + Vector vector = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 7))), + yVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 7)) = vector; + goto case 6; + } + + case 6: + { + Vector vector = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 6))), + yVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 6)) = vector; + goto case 5; + } + + case 5: + { + Vector vector = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 5))), + yVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 5)) = vector; + goto case 4; + } + + case 4: + { + Vector vector = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 4))), + yVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 4)) = vector; + goto case 3; + } + + case 3: + { + Vector vector = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 3))), + yVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 3)) = vector; + goto case 2; + } + + case 2: + { + Vector vector = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 2))), + yVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 2)) = vector; + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + AsVector(ref dRef, endIndex - (uint)Vector.Count) = end; + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + AsVector(ref dRefBeg) = beg; + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void VectorizedSmall(ref float xRef, float y, ref float dRef, nuint remainder, TTransformOperator xTransformOp = default, TBinaryOperator binaryOp = default) + { + switch (remainder) + { + case 7: + { + Unsafe.Add(ref dRef, 6) = binaryOp.Invoke(xTransformOp.Invoke(Unsafe.Add(ref xRef, 6)), + y); + goto case 6; + } + + case 6: + { + Unsafe.Add(ref dRef, 5) = binaryOp.Invoke(xTransformOp.Invoke(Unsafe.Add(ref xRef, 5)), + y); + goto case 5; + } + + case 5: + { + Unsafe.Add(ref dRef, 4) = binaryOp.Invoke(xTransformOp.Invoke(Unsafe.Add(ref xRef, 4)), + y); + goto case 4; + } + + case 4: + { + Unsafe.Add(ref dRef, 3) = binaryOp.Invoke(xTransformOp.Invoke(Unsafe.Add(ref xRef, 3)), + y); + goto case 3; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = binaryOp.Invoke(xTransformOp.Invoke(Unsafe.Add(ref xRef, 2)), + y); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = binaryOp.Invoke(xTransformOp.Invoke(Unsafe.Add(ref xRef, 1)), + y); + goto case 1; + } + + case 1: + { + dRef = binaryOp.Invoke(xTransformOp.Invoke(xRef), y); + goto case 0; + } + + case 0: + { + break; + } + } + } + } + + /// + /// Performs an element-wise operation on , , and , + /// and writes the results to . + /// + /// + /// Specifies the operation to perform on the pair-wise elements loaded from , , + /// and . 
+ /// + private static unsafe void InvokeSpanSpanSpanIntoSpan( + ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan z, Span destination, TTernaryOperator op = default) + where TTernaryOperator : struct, ITernaryOperator + { + if (x.Length != y.Length || x.Length != z.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ValidateInputOutputSpanNonOverlapping(x, destination); + ValidateInputOutputSpanNonOverlapping(y, destination); + ValidateInputOutputSpanNonOverlapping(z, destination); + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + ref float zRef = ref MemoryMarshal.GetReference(z); + ref float dRef = ref MemoryMarshal.GetReference(destination); + + nuint remainder = (uint)(x.Length); + + if (Vector.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector.Count)) + { + Vectorized(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + VectorizedSmall(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint length, TTernaryOperator op = default) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, (nint)(i)) = op.Invoke(Unsafe.Add(ref xRef, (nint)(i)), + Unsafe.Add(ref yRef, (nint)(i)), + Unsafe.Add(ref zRef, (nint)(i))); + } + } + + static void Vectorized(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder, TTernaryOperator op = default) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector beg = op.Invoke(AsVector(ref xRef), + AsVector(ref yRef), + AsVector(ref zRef)); + Vector end = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count)), + AsVector(ref yRef, remainder - (uint)(Vector.Count)), + AsVector(ref zRef, remainder - (uint)(Vector.Count))); + + if (remainder > (uint)(Vector.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pz = &zRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* zPtr = pz; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. 
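+ // ITernaryOperator follows the same empty-struct pattern with three operands. A
+ // hypothetical multiply-add style operator as a sketch (not necessarily one of the
+ // operators defined in this file):
+ //
+ //     readonly struct MultiplyAddOperator : ITernaryOperator
+ //     {
+ //         public float Invoke(float x, float y, float z) => (x * y) + z;
+ //         public Vector<float> Invoke(Vector<float> x, Vector<float> y, Vector<float> z)
+ //             => (x * y) + z;
+ //     }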
+ + nuint misalignment = ((uint)(sizeof(Vector)) - ((nuint)(dPtr) % (uint)(sizeof(Vector)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + zPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector))) == 0); + + remainder -= misalignment; + } + + Vector vector1; + Vector vector2; + Vector vector3; + Vector vector4; + + while (remainder >= (uint)(Vector.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 0)), + *(Vector*)(yPtr + (uint)(Vector.Count * 0)), + *(Vector*)(zPtr + (uint)(Vector.Count * 0))); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 1)), + *(Vector*)(yPtr + (uint)(Vector.Count * 1)), + *(Vector*)(zPtr + (uint)(Vector.Count * 1))); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 2)), + *(Vector*)(yPtr + (uint)(Vector.Count * 2)), + *(Vector*)(zPtr + (uint)(Vector.Count * 2))); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 3)), + *(Vector*)(yPtr + (uint)(Vector.Count * 3)), + *(Vector*)(zPtr + (uint)(Vector.Count * 3))); + + *(Vector*)(dPtr + (uint)(Vector.Count * 0)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 1)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 2)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 3)) = vector4; + + // We load, process, and store the next four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 4)), + *(Vector*)(yPtr + (uint)(Vector.Count * 4)), + *(Vector*)(zPtr + (uint)(Vector.Count * 4))); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 5)), + *(Vector*)(yPtr + (uint)(Vector.Count * 5)), + *(Vector*)(zPtr + (uint)(Vector.Count * 5))); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 6)), + *(Vector*)(yPtr + (uint)(Vector.Count * 6)), + *(Vector*)(zPtr + (uint)(Vector.Count * 6))); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 7)), + *(Vector*)(yPtr + (uint)(Vector.Count * 7)), + *(Vector*)(zPtr + (uint)(Vector.Count * 7))); + + *(Vector*)(dPtr + (uint)(Vector.Count * 4)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 5)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 6)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 7)) = vector4; + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector.Count * 8); + yPtr += (uint)(Vector.Count * 8); + zPtr += (uint)(Vector.Count * 8); + dPtr += (uint)(Vector.Count * 8); + + remainder -= (uint)(Vector.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + zRef = ref *zPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector.Count - 1)) & (nuint)(-Vector.Count); + + switch (remainder / (uint)(Vector.Count)) + { + case 8: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 8)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 8)), + AsVector(ref zRef, remainder - (uint)(Vector.Count * 8))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 8)) = vector; + goto case 7; + } + + case 7: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 7)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 7)), + AsVector(ref zRef, remainder - (uint)(Vector.Count * 7))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 7)) = vector; + goto case 6; + } + + case 6: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 6)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 6)), + AsVector(ref zRef, remainder - (uint)(Vector.Count * 6))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 6)) = vector; + goto case 5; + } + + case 5: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 5)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 5)), + AsVector(ref zRef, remainder - (uint)(Vector.Count * 5))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 5)) = vector; + goto case 4; + } + + case 4: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 4)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 4)), + AsVector(ref zRef, remainder - (uint)(Vector.Count * 4))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 4)) = vector; + goto case 3; + } + + case 3: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 3)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 3)), + AsVector(ref zRef, remainder - (uint)(Vector.Count * 3))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 3)) = vector; + goto case 2; + } + + case 2: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 2)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 2)), + AsVector(ref zRef, remainder - (uint)(Vector.Count * 2))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 2)) = vector; + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + AsVector(ref dRef, endIndex - (uint)Vector.Count) = end; + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + AsVector(ref dRefBeg) = beg; + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void VectorizedSmall(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder, TTernaryOperator op = default) + { + switch (remainder) + { + case 7: + { + Unsafe.Add(ref dRef, 6) = op.Invoke(Unsafe.Add(ref xRef, 6), + Unsafe.Add(ref yRef, 6), + Unsafe.Add(ref zRef, 6)); + goto case 6; + } + + case 6: + { + Unsafe.Add(ref dRef, 5) = op.Invoke(Unsafe.Add(ref xRef, 5), + Unsafe.Add(ref yRef, 5), + Unsafe.Add(ref zRef, 5)); + goto case 5; + } + + case 5: + { + Unsafe.Add(ref dRef, 4) = op.Invoke(Unsafe.Add(ref xRef, 4), + Unsafe.Add(ref yRef, 4), + Unsafe.Add(ref zRef, 4)); + goto case 4; + } + + case 4: + { + Unsafe.Add(ref dRef, 3) = op.Invoke(Unsafe.Add(ref xRef, 3), + Unsafe.Add(ref yRef, 3), + Unsafe.Add(ref zRef, 3)); + goto case 3; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) 
= op.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2), + Unsafe.Add(ref zRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = op.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1), + Unsafe.Add(ref zRef, 1)); + goto case 1; + } + + case 1: + { + dRef = op.Invoke(xRef, yRef, zRef); + break; + } + + case 0: + { + break; + } + } + } + } + + /// + /// Performs an element-wise operation on , , and , + /// and writes the results to . + /// + /// + /// Specifies the operation to perform on the pair-wise elements loaded from and + /// with . + /// + private static unsafe void InvokeSpanSpanScalarIntoSpan( + ReadOnlySpan x, ReadOnlySpan y, float z, Span destination, TTernaryOperator op = default) + where TTernaryOperator : struct, ITernaryOperator + { + if (x.Length != y.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ValidateInputOutputSpanNonOverlapping(x, destination); + ValidateInputOutputSpanNonOverlapping(y, destination); + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + ref float dRef = ref MemoryMarshal.GetReference(destination); + + nuint remainder = (uint)(x.Length); + + if (Vector.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector.Count)) + { + Vectorized(ref xRef, ref yRef, z, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + VectorizedSmall(ref xRef, ref yRef, z, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, ref yRef, z, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, ref float yRef, float z, ref float dRef, nuint length, TTernaryOperator op = default) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, (nint)(i)) = op.Invoke(Unsafe.Add(ref xRef, (nint)(i)), + Unsafe.Add(ref yRef, (nint)(i)), + z); + } + } + + static void Vectorized(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder, TTernaryOperator op = default) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector zVec = new Vector(z); + + Vector beg = op.Invoke(AsVector(ref xRef), + AsVector(ref yRef), + zVec); + Vector end = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count)), + AsVector(ref yRef, remainder - (uint)(Vector.Count)), + zVec); + + if (remainder > (uint)(Vector.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. 
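+ // `zVec` above broadcasts the scalar z into every lane once (new Vector<float>(z)) and
+ // is reused for the whole walk, so the scalar operand costs a single splat instead of a
+ // per-element load. The same idiom in isolation, as a sketch:
+ //
+ //     Vector<float> twos = new Vector<float>(2f); // every lane == 2f
+ //     Vector<float> scaled = values * twos;       // element-wise multiply by 2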
+ + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector)) - ((nuint)(dPtr) % (uint)(sizeof(Vector)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector))) == 0); + + remainder -= misalignment; + } + + Vector vector1; + Vector vector2; + Vector vector3; + Vector vector4; + + while (remainder >= (uint)(Vector.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 0)), + *(Vector*)(yPtr + (uint)(Vector.Count * 0)), + zVec); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 1)), + *(Vector*)(yPtr + (uint)(Vector.Count * 1)), + zVec); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 2)), + *(Vector*)(yPtr + (uint)(Vector.Count * 2)), + zVec); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 3)), + *(Vector*)(yPtr + (uint)(Vector.Count * 3)), + zVec); + + *(Vector*)(dPtr + (uint)(Vector.Count * 0)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 1)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 2)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 3)) = vector4; + + // We load, process, and store the next four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 4)), + *(Vector*)(yPtr + (uint)(Vector.Count * 4)), + zVec); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 5)), + *(Vector*)(yPtr + (uint)(Vector.Count * 5)), + zVec); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 6)), + *(Vector*)(yPtr + (uint)(Vector.Count * 6)), + zVec); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 7)), + *(Vector*)(yPtr + (uint)(Vector.Count * 7)), + zVec); + + *(Vector*)(dPtr + (uint)(Vector.Count * 4)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 5)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 6)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 7)) = vector4; + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector.Count * 8); + yPtr += (uint)(Vector.Count * 8); + dPtr += (uint)(Vector.Count * 8); + + remainder -= (uint)(Vector.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
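+ // The switch below is a Duff's-device-style jump table: execution enters at the case
+ // equal to the number of remaining vector-sized blocks (remainder / Count, at most 8)
+ // and falls through each lower case via goto, so all full blocks plus the overlapping
+ // `end` and `beg` stores are handled with a single computed branch.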
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector.Count - 1)) & (nuint)(-Vector.Count); + + switch (remainder / (uint)(Vector.Count)) + { + case 8: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 8)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 8)), + zVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 8)) = vector; + goto case 7; + } + + case 7: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 7)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 7)), + zVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 7)) = vector; + goto case 6; + } + + case 6: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 6)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 6)), + zVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 6)) = vector; + goto case 5; + } + + case 5: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 5)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 5)), + zVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 5)) = vector; + goto case 4; + } + + case 4: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 4)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 4)), + zVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 4)) = vector; + goto case 3; + } + + case 3: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 3)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 3)), + zVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 3)) = vector; + goto case 2; + } + + case 2: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 2)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 2)), + zVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 2)) = vector; + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + AsVector(ref dRef, endIndex - (uint)Vector.Count) = end; + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + AsVector(ref dRefBeg) = beg; + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void VectorizedSmall(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder, TTernaryOperator op = default) + { + switch (remainder) + { + case 7: + { + Unsafe.Add(ref dRef, 6) = op.Invoke(Unsafe.Add(ref xRef, 6), + Unsafe.Add(ref yRef, 6), + z); + goto case 6; + } + + case 6: + { + Unsafe.Add(ref dRef, 5) = op.Invoke(Unsafe.Add(ref xRef, 5), + Unsafe.Add(ref yRef, 5), + z); + goto case 5; + } + + case 5: + { + Unsafe.Add(ref dRef, 4) = op.Invoke(Unsafe.Add(ref xRef, 4), + Unsafe.Add(ref yRef, 4), + z); + goto case 4; + } + + case 4: + { + Unsafe.Add(ref dRef, 3) = op.Invoke(Unsafe.Add(ref xRef, 3), + Unsafe.Add(ref yRef, 3), + z); + goto case 3; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = op.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2), + z); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = op.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1), + z); + goto case 1; + } + + case 1: + { + dRef = op.Invoke(xRef, yRef, z); + goto case 0; + } + + case 0: + { + break; + } + } + } + } + + /// + /// Performs an element-wise operation on , , and , + /// and writes the results to . 
+ /// + /// + /// Specifies the operation to perform on the pair-wise element loaded from , with , + /// and the element loaded from . + /// + private static unsafe void InvokeSpanScalarSpanIntoSpan( + ReadOnlySpan x, float y, ReadOnlySpan z, Span destination, TTernaryOperator op = default) + where TTernaryOperator : struct, ITernaryOperator + { + if (x.Length != z.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ValidateInputOutputSpanNonOverlapping(x, destination); + ValidateInputOutputSpanNonOverlapping(z, destination); + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float zRef = ref MemoryMarshal.GetReference(z); + ref float dRef = ref MemoryMarshal.GetReference(destination); + + nuint remainder = (uint)(x.Length); + + if (Vector.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector.Count)) + { + Vectorized(ref xRef, y, ref zRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + VectorizedSmall(ref xRef, y, ref zRef, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, y, ref zRef, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, float y, ref float zRef, ref float dRef, nuint length, TTernaryOperator op = default) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, (nint)(i)) = op.Invoke(Unsafe.Add(ref xRef, (nint)(i)), + y, + Unsafe.Add(ref zRef, (nint)(i))); + } + } + + static void Vectorized(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder, TTernaryOperator op = default) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector yVec = new Vector(y); + + Vector beg = op.Invoke(AsVector(ref xRef), + yVec, + AsVector(ref zRef)); + Vector end = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count)), + yVec, + AsVector(ref zRef, remainder - (uint)(Vector.Count))); + + if (remainder > (uint)(Vector.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pz = &zRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* zPtr = pz; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. 
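+ // The `fixed` statements above pin x, z, and destination so the raw pointers stay valid
+ // while the GC is free to run. The same idiom in isolation, as a sketch (`SumPinned` is
+ // a hypothetical name):
+ //
+ //     static unsafe float SumPinned(float[] values)
+ //     {
+ //         float sum = 0f;
+ //         fixed (float* p = values) // pinned for the duration of this scope
+ //         {
+ //             for (int i = 0; i < values.Length; i++) sum += p[i];
+ //         }
+ //         return sum;
+ //     }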
+ + nuint misalignment = ((uint)(sizeof(Vector)) - ((nuint)(dPtr) % (uint)(sizeof(Vector)))) / sizeof(float); + + xPtr += misalignment; + zPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector))) == 0); + + remainder -= misalignment; + } + + Vector vector1; + Vector vector2; + Vector vector3; + Vector vector4; + + while (remainder >= (uint)(Vector.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 0)), + yVec, + *(Vector*)(zPtr + (uint)(Vector.Count * 0))); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 1)), + yVec, + *(Vector*)(zPtr + (uint)(Vector.Count * 1))); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 2)), + yVec, + *(Vector*)(zPtr + (uint)(Vector.Count * 2))); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 3)), + yVec, + *(Vector*)(zPtr + (uint)(Vector.Count * 3))); + + *(Vector*)(dPtr + (uint)(Vector.Count * 0)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 1)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 2)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 3)) = vector4; + + // We load, process, and store the next four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 4)), + yVec, + *(Vector*)(zPtr + (uint)(Vector.Count * 4))); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 5)), + yVec, + *(Vector*)(zPtr + (uint)(Vector.Count * 5))); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 6)), + yVec, + *(Vector*)(zPtr + (uint)(Vector.Count * 6))); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 7)), + yVec, + *(Vector*)(zPtr + (uint)(Vector.Count * 7))); + + *(Vector*)(dPtr + (uint)(Vector.Count * 4)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 5)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 6)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 7)) = vector4; + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector.Count * 8); + zPtr += (uint)(Vector.Count * 8); + dPtr += (uint)(Vector.Count * 8); + + remainder -= (uint)(Vector.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + zRef = ref *zPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. 
+ + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector.Count - 1)) & (nuint)(-Vector.Count); + + switch (remainder / (uint)(Vector.Count)) + { + case 8: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 8)), + yVec, + AsVector(ref zRef, remainder - (uint)(Vector.Count * 8))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 8)) = vector; + goto case 7; + } + + case 7: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 7)), + yVec, + AsVector(ref zRef, remainder - (uint)(Vector.Count * 7))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 7)) = vector; + goto case 6; + } + + case 6: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 6)), + yVec, + AsVector(ref zRef, remainder - (uint)(Vector.Count * 6))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 6)) = vector; + goto case 5; + } + + case 5: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 5)), + yVec, + AsVector(ref zRef, remainder - (uint)(Vector.Count * 5))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 5)) = vector; + goto case 4; + } + + case 4: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 4)), + yVec, + AsVector(ref zRef, remainder - (uint)(Vector.Count * 4))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 4)) = vector; + goto case 3; + } + + case 3: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 3)), + yVec, + AsVector(ref zRef, remainder - (uint)(Vector.Count * 3))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 3)) = vector; + goto case 2; + } + + case 2: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 2)), + yVec, + AsVector(ref zRef, remainder - (uint)(Vector.Count * 2))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 2)) = vector; + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + AsVector(ref dRef, endIndex - (uint)Vector.Count) = end; + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + AsVector(ref dRefBeg) = beg; + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void VectorizedSmall(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder, TTernaryOperator op = default) + { + switch (remainder) + { + case 7: + { + Unsafe.Add(ref dRef, 6) = op.Invoke(Unsafe.Add(ref xRef, 6), + y, + Unsafe.Add(ref zRef, 6)); + goto case 6; + } + + case 6: + { + Unsafe.Add(ref dRef, 5) = op.Invoke(Unsafe.Add(ref xRef, 5), + y, + Unsafe.Add(ref zRef, 5)); + goto case 5; + } + + case 5: + { + Unsafe.Add(ref dRef, 4) = op.Invoke(Unsafe.Add(ref xRef, 4), + y, + Unsafe.Add(ref zRef, 4)); + goto case 4; + } + + case 4: + { + Unsafe.Add(ref dRef, 3) = op.Invoke(Unsafe.Add(ref xRef, 3), + y, + Unsafe.Add(ref zRef, 3)); + goto case 3; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = op.Invoke(Unsafe.Add(ref xRef, 2), + y, + Unsafe.Add(ref zRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = op.Invoke(Unsafe.Add(ref xRef, 1), + y, + Unsafe.Add(ref zRef, 1)); + goto case 1; + } + + case 1: + { + dRef = op.Invoke(xRef, y, zRef); + goto case 0; + } + + case 0: + { + break; + } + } + } + } + + /// Loads a from . 
+ [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static ref Vector AsVector(ref float start) => + ref Unsafe.As>(ref start); + + /// Loads a that begins at the specified from . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static ref Vector AsVector(ref float start, int offset) => + ref Unsafe.As>( + ref Unsafe.Add(ref start, offset)); + + /// Loads a that begins at the specified from . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static ref Vector AsVector(ref float start, nuint offset) => + ref Unsafe.As>( + ref Unsafe.Add(ref start, (nint)(offset))); + + /// Loads a that begins at the specified from . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static ref Vector AsVector(ref int start, int offset) => + ref Unsafe.As>( + ref Unsafe.Add(ref start, offset)); + + /// Gets whether the specified is positive. + private static bool IsPositive(float f) => !IsNegative(f); + + /// Gets whether each specified is positive. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector IsPositive(Vector vector) => + ((Vector)Vector.GreaterThan(((Vector)vector), Vector.Zero)); + + /// Gets whether the specified is negative. + private static unsafe bool IsNegative(float f) => *(int*)&f < 0; + + /// Gets whether each specified is negative. + private static Vector IsNegative(Vector f) => + (Vector)Vector.LessThan((Vector)f, Vector.Zero); + + /// Gets the base 2 logarithm of . + private static float Log2(float x) => MathF.Log(x, 2); + + /// + /// Gets a vector mask that will be all-ones-set for the first elements + /// and zero for all other elements. + /// + private static Vector CreateAlignmentMaskSingleVector(int count) + { + Debug.Assert(Vector.Count is 4 or 8 or 16); + + return AsVector( + ref Unsafe.As(ref MemoryMarshal.GetReference(AlignmentUInt32Mask_16x16)), + (count * 16)); + } + + /// + /// Gets a vector mask that will be all-ones-set for the last elements + /// and zero for all other elements. 
+ /// + private static Vector CreateRemainderMaskSingleVector(int count) + { + Debug.Assert(Vector.Count is 4 or 8 or 16); + + return AsVector( + ref Unsafe.As(ref MemoryMarshal.GetReference(RemainderUInt32Mask_16x16)), + (count * 16) + (16 - Vector.Count)); + } + + /// x + y + private readonly struct AddOperator : IAggregationOperator + { + public float Invoke(float x, float y) => x + y; + public Vector Invoke(Vector x, Vector y) => x + y; + public float IdentityValue => 0; + } + + /// x - y + private readonly struct SubtractOperator : IBinaryOperator + { + public float Invoke(float x, float y) => x - y; + public Vector Invoke(Vector x, Vector y) => x - y; + } + + /// (x - y) * (x - y) + private readonly struct SubtractSquaredOperator : IBinaryOperator + { + public float Invoke(float x, float y) + { + float tmp = x - y; + return tmp * tmp; + } + + public Vector Invoke(Vector x, Vector y) + { + Vector tmp = x - y; + return tmp * tmp; + } + } + + /// x * y + private readonly struct MultiplyOperator : IAggregationOperator + { + public float Invoke(float x, float y) => x * y; + public Vector Invoke(Vector x, Vector y) => x * y; + public float IdentityValue => 1; + } + + /// x / y + private readonly struct DivideOperator : IBinaryOperator + { + public float Invoke(float x, float y) => x / y; + public Vector Invoke(Vector x, Vector y) => x / y; + } + + private interface IIndexOfOperator + { + int Invoke(ref float result, float current, int resultIndex, int curIndex); + int Invoke(Vector result, Vector resultIndex); + void Invoke(ref Vector result, Vector current, ref Vector resultIndex, Vector curIndex); + } + + /// Returns the index of MathF.Max(x, y) + private readonly struct IndexOfMaxOperator : IIndexOfOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int Invoke(Vector result, Vector resultIndex) + { + float curMax = result[0]; + int curIn = resultIndex[0]; + for (int i = 1; i < Vector.Count; i++) + { + if (result[i] == curMax && IsNegative(curMax) && !IsNegative(result[i])) + { + curMax = result[i]; + curIn = resultIndex[i]; + } + else if (result[i] > curMax) + { + curMax = result[i]; + curIn = resultIndex[i]; + } + } + + return curIn; + } - if (Vector.IsHardwareAccelerated && x.Length >= Vector.Count * 2) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Invoke(ref Vector result, Vector current, ref Vector resultIndex, Vector curIndex) { - ref float xRef = ref MemoryMarshal.GetReference(x); + Vector lessThanMask = Vector.GreaterThan(result, current); - // Load the first vector as the initial set of results - Vector resultVector = load.Invoke(AsVector(ref xRef, 0)); - int oneVectorFromEnd = x.Length - Vector.Count; + Vector equalMask = Vector.Equals(result, current); - // Aggregate additional vectors into the result as long as there's at - // least one full vector left to process. - i = Vector.Count; - do + if (equalMask != Vector.Zero) { - resultVector = aggregate.Invoke(resultVector, load.Invoke(AsVector(ref xRef, i))); - i += Vector.Count; - } - while (i <= oneVectorFromEnd); + Vector negativeMask = IsNegative(current); + Vector lessThanIndexMask = Vector.LessThan(resultIndex, curIndex); - // Aggregate the lanes in the vector back into the scalar result - for (int f = 0; f < Vector.Count; f++) - { - result = aggregate.Invoke(result, resultVector[f]); + lessThanMask |= ((Vector)~negativeMask & equalMask) | ((Vector)IsNegative(result) & equalMask & lessThanIndexMask); } - } - - // Aggregate the remaining items in the input span. 
- for (; (uint)i < (uint)x.Length; i++) - { - result = aggregate.Invoke(result, load.Invoke(x[i])); - } - return result; - } + result = Vector.ConditionalSelect(lessThanMask, result, current); - private static float Aggregate( - float identityValue, ReadOnlySpan x, ReadOnlySpan y, TBinary binary = default, TAggregate aggregate = default) - where TBinary : struct, IBinaryOperator - where TAggregate : struct, IBinaryOperator - { - // Initialize the result to the identity value - float result = identityValue; - int i = 0; + resultIndex = Vector.ConditionalSelect(lessThanMask, resultIndex, curIndex); + } - if (Vector.IsHardwareAccelerated && x.Length >= Vector.Count * 2) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int Invoke(ref float result, float current, int resultIndex, int curIndex) { - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float yRef = ref MemoryMarshal.GetReference(y); - - // Load the first vector as the initial set of results - Vector resultVector = binary.Invoke(AsVector(ref xRef, 0), AsVector(ref yRef, 0)); - int oneVectorFromEnd = x.Length - Vector.Count; - - // Aggregate additional vectors into the result as long as there's at - // least one full vector left to process. - i = Vector.Count; - do + if (result == current) { - resultVector = aggregate.Invoke(resultVector, binary.Invoke(AsVector(ref xRef, i), AsVector(ref yRef, i))); - i += Vector.Count; + if (IsNegative(result) && !IsNegative(current)) + { + result = current; + return curIndex; + } } - while (i <= oneVectorFromEnd); - - // Aggregate the lanes in the vector back into the scalar result - for (int f = 0; f < Vector.Count; f++) + else if (current > result) { - result = aggregate.Invoke(result, resultVector[f]); + result = current; + return curIndex; } - } - // Aggregate the remaining items in the input span. - for (; (uint)i < (uint)x.Length; i++) - { - result = aggregate.Invoke(result, binary.Invoke(x[i], y[i])); + return resultIndex; } - - return result; } - private static void InvokeSpanIntoSpan( - ReadOnlySpan x, Span destination, TUnaryOperator op = default) - where TUnaryOperator : struct, IUnaryOperator + private readonly struct IndexOfMaxMagnitudeOperator : IIndexOfOperator { - if (x.Length > destination.Length) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int Invoke(ref float result, float current, int resultIndex, int curIndex) { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } + float curMaxAbs = MathF.Abs(result); + float currentAbs = MathF.Abs(current); - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float dRef = ref MemoryMarshal.GetReference(destination); - int i = 0, oneVectorFromEnd; + if (curMaxAbs == currentAbs) + { + if (IsNegative(result) && !IsNegative(current)) + { + result = current; + return curIndex; + } + } + else if (currentAbs > curMaxAbs) + { + result = current; + return curIndex; + } - if (Vector.IsHardwareAccelerated) + return resultIndex; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int Invoke(Vector result, Vector maxIndex) { - oneVectorFromEnd = x.Length - Vector.Count; - if (oneVectorFromEnd >= 0) + float curMax = result[0]; + int curIn = maxIndex[0]; + for (int i = 1; i < Vector.Count; i++) { - // Loop handling one vector at a time. 
- do + if (MathF.Abs(result[i]) == MathF.Abs(curMax) && IsNegative(curMax) && !IsNegative(result[i])) { - AsVector(ref dRef, i) = op.Invoke(AsVector(ref xRef, i)); - - i += Vector.Count; + curMax = result[i]; + curIn = maxIndex[i]; } - while (i <= oneVectorFromEnd); - - // Handle any remaining elements with a final vector. - if (i != x.Length) + else if (MathF.Abs(result[i]) > MathF.Abs(curMax)) { - int lastVectorIndex = x.Length - Vector.Count; - AsVector(ref dRef, lastVectorIndex) = op.Invoke(AsVector(ref xRef, lastVectorIndex)); + curMax = result[i]; + curIn = maxIndex[i]; } - - return; } - } - - // Loop handling one element at a time. - while (i < x.Length) - { - Unsafe.Add(ref dRef, i) = op.Invoke(Unsafe.Add(ref xRef, i)); - i++; + return curIn; } - } - private static void InvokeSpanSpanIntoSpan( - ReadOnlySpan x, ReadOnlySpan y, Span destination, TBinaryOperator op = default) - where TBinaryOperator : struct, IBinaryOperator - { - if (x.Length != y.Length) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Invoke(ref Vector result, Vector current, ref Vector resultIndex, Vector curIndex) { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); - } + Vector maxMag = Vector.Abs(result), currentMag = Vector.Abs(current); - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } + Vector lessThanMask = Vector.GreaterThan(maxMag, currentMag); - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float yRef = ref MemoryMarshal.GetReference(y); - ref float dRef = ref MemoryMarshal.GetReference(destination); - int i = 0, oneVectorFromEnd; + Vector equalMask = Vector.Equals(result, current); - if (Vector.IsHardwareAccelerated) - { - oneVectorFromEnd = x.Length - Vector.Count; - if (oneVectorFromEnd >= 0) + if (equalMask != Vector.Zero) { - // Loop handling one vector at a time. - do - { - AsVector(ref dRef, i) = op.Invoke(AsVector(ref xRef, i), - AsVector(ref yRef, i)); - - i += Vector.Count; - } - while (i <= oneVectorFromEnd); + Vector negativeMask = IsNegative(current); + Vector lessThanIndexMask = Vector.LessThan(resultIndex, curIndex); - // Handle any remaining elements with a final vector. 
- if (i != x.Length) - { - int lastVectorIndex = x.Length - Vector.Count; - AsVector(ref dRef, lastVectorIndex) = op.Invoke(AsVector(ref xRef, lastVectorIndex), - AsVector(ref yRef, lastVectorIndex)); - } - - return; + lessThanMask |= ((Vector)~negativeMask & equalMask) | ((Vector)IsNegative(result) & equalMask & lessThanIndexMask); } - } - while (i < x.Length) - { - Unsafe.Add(ref dRef, i) = op.Invoke(Unsafe.Add(ref xRef, i), - Unsafe.Add(ref yRef, i)); + result = Vector.ConditionalSelect(lessThanMask, result, current); - i++; + resultIndex = Vector.ConditionalSelect(lessThanMask, resultIndex, curIndex); } } - private static void InvokeSpanScalarIntoSpan( - ReadOnlySpan x, float y, Span destination, TBinaryOperator op = default) - where TBinaryOperator : struct, IBinaryOperator + private readonly struct IndexOfMinOperator : IIndexOfOperator { - if (x.Length > destination.Length) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int Invoke(ref float result, float current, int resultIndex, int curIndex) { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } + if (result == current) + { + if (IsPositive(result) && !IsPositive(current)) + { + result = current; + return curIndex; + } + } + else if (current < result) + { + result = current; + return curIndex; + } - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float dRef = ref MemoryMarshal.GetReference(destination); - int i = 0, oneVectorFromEnd; + return resultIndex; + } - if (Vector.IsHardwareAccelerated) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int Invoke(Vector result, Vector resultIndex) { - oneVectorFromEnd = x.Length - Vector.Count; - if (oneVectorFromEnd >= 0) + float curMin = result[0]; + int curIn = resultIndex[0]; + for (int i = 1; i < Vector.Count; i++) { - // Loop handling one vector at a time. - Vector yVec = new(y); - do + if (result[i] == curMin && IsPositive(curMin) && !IsPositive(result[i])) { - AsVector(ref dRef, i) = op.Invoke(AsVector(ref xRef, i), - yVec); - - i += Vector.Count; + curMin = result[i]; + curIn = resultIndex[i]; } - while (i <= oneVectorFromEnd); - - // Handle any remaining elements with a final vector. - if (i != x.Length) + else if (result[i] < curMin) { - int lastVectorIndex = x.Length - Vector.Count; - AsVector(ref dRef, lastVectorIndex) = op.Invoke(AsVector(ref xRef, lastVectorIndex), - yVec); + curMin = result[i]; + curIn = resultIndex[i]; } - - return; } + + return curIn; } - // Loop handling one element at a time. 
- while (i < x.Length) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Invoke(ref Vector result, Vector current, ref Vector resultIndex, Vector curIndex) { - Unsafe.Add(ref dRef, i) = op.Invoke(Unsafe.Add(ref xRef, i), - y); + Vector lessThanMask = Vector.LessThan(result, current); + + Vector equalMask = Vector.Equals(result, current); + + if (equalMask != Vector.Zero) + { + Vector negativeMask = IsNegative(current); + Vector lessThanIndexMask = Vector.LessThan(resultIndex, curIndex); + + lessThanMask |= ((Vector)negativeMask & equalMask) | (~(Vector)IsNegative(result) & equalMask & lessThanIndexMask); + } + + result = Vector.ConditionalSelect(lessThanMask, result, current); - i++; + resultIndex = Vector.ConditionalSelect(lessThanMask, resultIndex, curIndex); } } - private static void InvokeSpanSpanSpanIntoSpan( - ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan z, Span destination, TTernaryOperator op = default) - where TTernaryOperator : struct, ITernaryOperator + private readonly struct IndexOfMinMagnitudeOperator : IIndexOfOperator { - if (x.Length != y.Length || x.Length != z.Length) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int Invoke(ref float result, float current, int resultIndex, int curIndex) { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); - } + float curMinAbs = MathF.Abs(result); + float currentAbs = MathF.Abs(current); + if (curMinAbs == currentAbs) + { + if (IsPositive(result) && !IsPositive(current)) + { + result = current; + return curIndex; + } + } + else if (currentAbs < curMinAbs) + { + result = current; + return curIndex; + } - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); + return resultIndex; } - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float yRef = ref MemoryMarshal.GetReference(y); - ref float zRef = ref MemoryMarshal.GetReference(z); - ref float dRef = ref MemoryMarshal.GetReference(destination); - int i = 0, oneVectorFromEnd; - - if (Vector.IsHardwareAccelerated) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int Invoke(Vector result, Vector resultIndex) { - oneVectorFromEnd = x.Length - Vector.Count; - if (oneVectorFromEnd >= 0) + float curMin = result[0]; + int curIn = resultIndex[0]; + for (int i = 1; i < Vector.Count; i++) { - // Loop handling one vector at a time. - do + if (MathF.Abs(result[i]) == MathF.Abs(curMin) && IsPositive(curMin) && !IsPositive(result[i])) { - AsVector(ref dRef, i) = op.Invoke(AsVector(ref xRef, i), - AsVector(ref yRef, i), - AsVector(ref zRef, i)); - - i += Vector.Count; + curMin = result[i]; + curIn = resultIndex[i]; } - while (i <= oneVectorFromEnd); - - // Handle any remaining elements with a final vector. - if (i != x.Length) + else if (MathF.Abs(result[i]) < MathF.Abs(curMin)) { - int lastVectorIndex = x.Length - Vector.Count; - AsVector(ref dRef, lastVectorIndex) = op.Invoke(AsVector(ref xRef, lastVectorIndex), - AsVector(ref yRef, lastVectorIndex), - AsVector(ref zRef, lastVectorIndex)); + curMin = result[i]; + curIn = resultIndex[i]; } - - return; } - } - // Loop handling one element at a time. 
- while (i < x.Length) - { - Unsafe.Add(ref dRef, i) = op.Invoke(Unsafe.Add(ref xRef, i), - Unsafe.Add(ref yRef, i), - Unsafe.Add(ref zRef, i)); - - i++; + return curIn; } - } - private static void InvokeSpanSpanScalarIntoSpan( - ReadOnlySpan x, ReadOnlySpan y, float z, Span destination, TTernaryOperator op = default) - where TTernaryOperator : struct, ITernaryOperator - { - if (x.Length != y.Length) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Invoke(ref Vector result, Vector current, ref Vector resultIndex, Vector curIndex) { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); - } + Vector minMag = Vector.Abs(result), currentMag = Vector.Abs(current); - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } + Vector lessThanMask = Vector.LessThan(minMag, currentMag); - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float yRef = ref MemoryMarshal.GetReference(y); - ref float dRef = ref MemoryMarshal.GetReference(destination); - int i = 0, oneVectorFromEnd; + Vector equalMask = Vector.Equals(result, current); - if (Vector.IsHardwareAccelerated) - { - oneVectorFromEnd = x.Length - Vector.Count; - if (oneVectorFromEnd >= 0) + if (equalMask != Vector.Zero) { - Vector zVec = new(z); - - // Loop handling one vector at a time. - do - { - AsVector(ref dRef, i) = op.Invoke(AsVector(ref xRef, i), - AsVector(ref yRef, i), - zVec); + Vector negativeMask = IsNegative(current); + Vector lessThanIndexMask = Vector.LessThan(resultIndex, curIndex); - i += Vector.Count; - } - while (i <= oneVectorFromEnd); + lessThanMask |= ((Vector)negativeMask & equalMask) | (~(Vector)IsNegative(result) & equalMask & lessThanIndexMask); + } - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - int lastVectorIndex = x.Length - Vector.Count; - AsVector(ref dRef, lastVectorIndex) = op.Invoke(AsVector(ref xRef, lastVectorIndex), - AsVector(ref yRef, lastVectorIndex), - zVec); - } + result = Vector.ConditionalSelect(lessThanMask, result, current); - return; - } + resultIndex = Vector.ConditionalSelect(lessThanMask, resultIndex, curIndex); } + } - // Loop handling one element at a time. - while (i < x.Length) - { - Unsafe.Add(ref dRef, i) = op.Invoke(Unsafe.Add(ref xRef, i), - Unsafe.Add(ref yRef, i), - z); + /// MathF.Max(x, y) (but without guaranteed NaN propagation) + private readonly struct MaxOperator : IBinaryOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public float Invoke(float x, float y) => + x == y ? + (IsNegative(x) ? y : x) : + (y > x ? 
y : x); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Vector Invoke(Vector x, Vector y) => + Vector.ConditionalSelect(Vector.Equals(x, y), + Vector.ConditionalSelect(IsNegative(x), y, x), + Vector.Max(x, y)); + } - i++; - } + /// MathF.Max(x, y) + private readonly struct MaxPropagateNaNOperator : IBinaryOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public float Invoke(float x, float y) => MathF.Max(x, y); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Vector Invoke(Vector x, Vector y) => + Vector.ConditionalSelect(Vector.Equals(x, x), + Vector.ConditionalSelect(Vector.Equals(y, y), + Vector.ConditionalSelect(Vector.Equals(x, y), + Vector.ConditionalSelect(IsNegative(x), y, x), + Vector.Max(x, y)), + y), + x); } - private static void InvokeSpanScalarSpanIntoSpan( - ReadOnlySpan x, float y, ReadOnlySpan z, Span destination, TTernaryOperator op = default) - where TTernaryOperator : struct, ITernaryOperator + /// Operator to get x or y based on which has the larger MathF.Abs (but NaNs may not be propagated) + private readonly struct MaxMagnitudeOperator : IBinaryOperator { - if (x.Length != z.Length) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public float Invoke(float x, float y) { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + float xMag = MathF.Abs(x), yMag = MathF.Abs(y); + return + yMag == xMag ? + (IsNegative(x) ? y : x) : + (xMag > yMag ? x : y); } - if (x.Length > destination.Length) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Vector Invoke(Vector x, Vector y) { - ThrowHelper.ThrowArgument_DestinationTooShort(); + Vector xMag = Vector.Abs(x), yMag = Vector.Abs(y); + return + Vector.ConditionalSelect(Vector.Equals(xMag, yMag), + Vector.ConditionalSelect(IsNegative(x), y, x), + Vector.ConditionalSelect(Vector.GreaterThan(xMag, yMag), x, y)); } + } - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float zRef = ref MemoryMarshal.GetReference(z); - ref float dRef = ref MemoryMarshal.GetReference(destination); - int i = 0, oneVectorFromEnd; - - if (Vector.IsHardwareAccelerated) + /// Operator to get x or y based on which has the larger MathF.Abs + private readonly struct MaxMagnitudePropagateNaNOperator : IBinaryOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public float Invoke(float x, float y) { - oneVectorFromEnd = x.Length - Vector.Count; - if (oneVectorFromEnd >= 0) - { - Vector yVec = new(y); - - // Loop handling one vector at a time. - do - { - AsVector(ref dRef, i) = op.Invoke(AsVector(ref xRef, i), - yVec, - AsVector(ref zRef, i)); - - i += Vector.Count; - } - while (i <= oneVectorFromEnd); - - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - int lastVectorIndex = x.Length - Vector.Count; - AsVector(ref dRef, lastVectorIndex) = op.Invoke(AsVector(ref xRef, lastVectorIndex), - yVec, - AsVector(ref zRef, lastVectorIndex)); - } - - return; - } + float xMag = MathF.Abs(x), yMag = MathF.Abs(y); + return xMag > yMag || float.IsNaN(xMag) || (xMag == yMag && !IsNegative(x)) ? x : y; } - // Loop handling one element at a time. 
- while (i < x.Length) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Vector Invoke(Vector x, Vector y) { - Unsafe.Add(ref dRef, i) = op.Invoke(Unsafe.Add(ref xRef, i), - y, - Unsafe.Add(ref zRef, i)); - - i++; + Vector xMag = Vector.Abs(x), yMag = Vector.Abs(y); + return + Vector.ConditionalSelect(Vector.Equals(x, x), + Vector.ConditionalSelect(Vector.Equals(y, y), + Vector.ConditionalSelect(Vector.Equals(xMag, yMag), + Vector.ConditionalSelect(IsNegative(x), y, x), + Vector.ConditionalSelect(Vector.GreaterThan(xMag, yMag), x, y)), + y), + x); } } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static ref Vector AsVector(ref float start, int offset) => - ref Unsafe.As>( - ref Unsafe.Add(ref start, offset)); - - private readonly struct AddOperator : IBinaryOperator + /// MathF.Min(x, y) (but NaNs may not be propagated) + private readonly struct MinOperator : IBinaryOperator { - public float Invoke(float x, float y) => x + y; - public Vector Invoke(Vector x, Vector y) => x + y; + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public float Invoke(float x, float y) => + x == y ? + (IsNegative(y) ? y : x) : + (y < x ? y : x); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Vector Invoke(Vector x, Vector y) => + Vector.ConditionalSelect(Vector.Equals(x, y), + Vector.ConditionalSelect(IsNegative(y), y, x), + Vector.Min(x, y)); } - private readonly struct SubtractOperator : IBinaryOperator + /// MathF.Min(x, y) + private readonly struct MinPropagateNaNOperator : IBinaryOperator { - public float Invoke(float x, float y) => x - y; - public Vector Invoke(Vector x, Vector y) => x - y; + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public float Invoke(float x, float y) => MathF.Min(x, y); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Vector Invoke(Vector x, Vector y) => + Vector.ConditionalSelect(Vector.Equals(x, x), + Vector.ConditionalSelect(Vector.Equals(y, y), + Vector.ConditionalSelect(Vector.Equals(x, y), + Vector.ConditionalSelect(IsNegative(x), x, y), + Vector.Min(x, y)), + y), + x); } - private readonly struct SubtractSquaredOperator : IBinaryOperator + /// Operator to get x or y based on which has the smaller MathF.Abs (but NaNs may not be propagated) + private readonly struct MinMagnitudeOperator : IBinaryOperator { + [MethodImpl(MethodImplOptions.AggressiveInlining)] public float Invoke(float x, float y) { - float tmp = x - y; - return tmp * tmp; + float xMag = MathF.Abs(x), yMag = MathF.Abs(y); + return + yMag == xMag ? + (IsNegative(y) ? y : x) : + (yMag < xMag ? y : x); } + [MethodImpl(MethodImplOptions.AggressiveInlining)] public Vector Invoke(Vector x, Vector y) { - Vector tmp = x - y; - return tmp * tmp; + Vector xMag = Vector.Abs(x), yMag = Vector.Abs(y); + return + Vector.ConditionalSelect(Vector.Equals(yMag, xMag), + Vector.ConditionalSelect(IsNegative(y), y, x), + Vector.ConditionalSelect(Vector.LessThan(yMag, xMag), y, x)); } } - private readonly struct MultiplyOperator : IBinaryOperator + /// Operator to get x or y based on which has the smaller MathF.Abs + private readonly struct MinMagnitudePropagateNaNOperator : IBinaryOperator { - public float Invoke(float x, float y) => x * y; - public Vector Invoke(Vector x, Vector y) => x * y; - } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public float Invoke(float x, float y) + { + float xMag = MathF.Abs(x), yMag = MathF.Abs(y); + return xMag < yMag || float.IsNaN(xMag) || (xMag == yMag && IsNegative(x)) ? 
x : y; + } - private readonly struct DivideOperator : IBinaryOperator - { - public float Invoke(float x, float y) => x / y; - public Vector Invoke(Vector x, Vector y) => x / y; + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Vector Invoke(Vector x, Vector y) + { + Vector xMag = Vector.Abs(x), yMag = Vector.Abs(y); + + return + Vector.ConditionalSelect(Vector.Equals(x, x), + Vector.ConditionalSelect(Vector.Equals(y, y), + Vector.ConditionalSelect(Vector.Equals(yMag, xMag), + Vector.ConditionalSelect(IsNegative(x), x, y), + Vector.ConditionalSelect(Vector.LessThan(xMag, yMag), x, y)), + y), + x); + } } + /// -x private readonly struct NegateOperator : IUnaryOperator { + public bool CanVectorize => true; public float Invoke(float x) => -x; public Vector Invoke(Vector x) => -x; } + /// (x + y) * z private readonly struct AddMultiplyOperator : ITernaryOperator { public float Invoke(float x, float y, float z) => (x + y) * z; public Vector Invoke(Vector x, Vector y, Vector z) => (x + y) * z; } + /// (x * y) + z private readonly struct MultiplyAddOperator : ITernaryOperator { public float Invoke(float x, float y, float z) => (x * y) + z; public Vector Invoke(Vector x, Vector y, Vector z) => (x * y) + z; } + /// x private readonly struct IdentityOperator : IUnaryOperator { + public bool CanVectorize => true; public float Invoke(float x) => x; public Vector Invoke(Vector x) => x; } + /// x * x private readonly struct SquaredOperator : IUnaryOperator { + public bool CanVectorize => true; public float Invoke(float x) => x * x; public Vector Invoke(Vector x) => x * x; } + /// MathF.Abs(x) private readonly struct AbsoluteOperator : IUnaryOperator { + public bool CanVectorize => true; public float Invoke(float x) => MathF.Abs(x); + public Vector Invoke(Vector x) => Vector.Abs(x); + } - public Vector Invoke(Vector x) - { - Vector raw = Vector.AsVectorUInt32(x); - Vector mask = new Vector(0x7FFFFFFF); - return Vector.AsVectorSingle(raw & mask); - } + /// MathF.Exp(x) + private readonly struct ExpOperator : IUnaryOperator + { + public bool CanVectorize => false; + public float Invoke(float x) => MathF.Exp(x); + public Vector Invoke(Vector x) => + // requires ShiftLeft (.NET 7+) + throw new NotImplementedException(); + } + + /// MathF.Sinh(x) + private readonly struct SinhOperator : IUnaryOperator + { + public bool CanVectorize => false; + public float Invoke(float x) => MathF.Sinh(x); + public Vector Invoke(Vector x) => + // requires ShiftLeft (.NET 7+) + throw new NotImplementedException(); + } + + /// MathF.Cosh(x) + private readonly struct CoshOperator : IUnaryOperator + { + public bool CanVectorize => false; + public float Invoke(float x) => MathF.Cosh(x); + public Vector Invoke(Vector x) => + // requires ShiftLeft (.NET 7+) + throw new NotImplementedException(); + } + + /// MathF.Tanh(x) + private readonly struct TanhOperator : IUnaryOperator + { + public bool CanVectorize => false; + public float Invoke(float x) => MathF.Tanh(x); + public Vector Invoke(Vector x) => + // requires ShiftLeft (.NET 7+) + throw new NotImplementedException(); + } + + /// MathF.Log(x) + private readonly struct LogOperator : IUnaryOperator + { + public bool CanVectorize => false; + public float Invoke(float x) => MathF.Log(x); + public Vector Invoke(Vector x) => + // requires ShiftRightArithmetic (.NET 7+) + throw new NotImplementedException(); + } + + /// MathF.Log2(x) + private readonly struct Log2Operator : IUnaryOperator + { + public bool CanVectorize => false; + public float Invoke(float x) => Log2(x); + 
public Vector Invoke(Vector x) => + // requires ShiftRightArithmetic (.NET 7+) + throw new NotImplementedException(); + } + + /// 1f / (1f + MathF.Exp(-x)) + private readonly struct SigmoidOperator : IUnaryOperator + { + public bool CanVectorize => false; + public float Invoke(float x) => 1.0f / (1.0f + MathF.Exp(-x)); + public Vector Invoke(Vector x) => + // requires ShiftRightArithmetic (.NET 7+) + throw new NotImplementedException(); } + /// Operator that takes one input value and returns a single value. private interface IUnaryOperator { + bool CanVectorize { get; } float Invoke(float x); Vector Invoke(Vector x); } + /// Operator that takes two input values and returns a single value. private interface IBinaryOperator { float Invoke(float x, float y); Vector Invoke(Vector x, Vector y); } + /// that specializes horizontal aggregation of all elements in a vector. + private interface IAggregationOperator : IBinaryOperator + { + float IdentityValue { get; } + } + + /// Operator that takes three input values and returns a single value. private interface ITernaryOperator { float Invoke(float x, float y, float z); diff --git a/src/libraries/System.Numerics.Tensors/src/System/ThrowHelper.cs b/src/libraries/System.Numerics.Tensors/src/System/ThrowHelper.cs index 902b27787e856..272991aed44ab 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/ThrowHelper.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/ThrowHelper.cs @@ -18,5 +18,9 @@ public static void ThrowArgument_SpansMustHaveSameLength() => [DoesNotReturn] public static void ThrowArgument_SpansMustBeNonEmpty() => throw new ArgumentException(SR.Argument_SpansMustBeNonEmpty); + + [DoesNotReturn] + public static void ThrowArgument_InputAndDestinationSpanMustNotOverlap() => + throw new ArgumentException(SR.Argument_InputAndDestinationSpanMustNotOverlap, "destination"); } } diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs index 777ab49609856..09aa13ae35800 100644 --- a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs +++ b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs @@ -6,6 +6,7 @@ using System.Linq; using System.Runtime.InteropServices; using Xunit; +using Xunit.Sdk; #pragma warning disable xUnit1025 // reporting duplicate test cases due to not distinguishing 0.0 from -0.0 @@ -13,15 +14,25 @@ namespace System.Numerics.Tensors.Tests { public static partial class TensorPrimitivesTests { - private const double Tolerance = 0.0001; - + #region Test Utilities public static IEnumerable TensorLengthsIncluding0 => TensorLengths.Concat(new object[][] { [0] }); public static IEnumerable TensorLengths => - from length in Enumerable.Range(1, 128) + from length in Enumerable.Range(1, 256) select new object[] { length }; + public static IEnumerable VectorLengthAndIteratedRange(float min, float max, float increment) + { + foreach (int length in new[] { 4, 8, 16 }) + { + for (float f = min; f <= max; f += increment) + { + yield return new object[] { length, f }; + } + } + } + private static readonly Random s_random = new Random(20230828); private static BoundedMemory CreateTensor(int size) => BoundedMemory.Allocate(size); @@ -41,300 +52,300 @@ private static void FillTensor(Span tensor) } } - private static float NextSingle() - { + private static float NextSingle() => // For testing purposes, get a mix of negative and positive values. 
- return (float)((s_random.NextDouble() * 2) - 1); - } + (float)((s_random.NextDouble() * 2) - 1); - [Theory] - [MemberData(nameof(TensorLengthsIncluding0))] - public static void AddTwoTensors(int tensorLength) + private static void AssertEqualTolerance(double expected, double actual, double tolerance = 0.00001f) { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength); - - TensorPrimitives.Add(x, y, destination); - - for (int i = 0; i < tensorLength; i++) + double diff = Math.Abs(expected - actual); + if (diff > tolerance && + diff > Math.Max(Math.Abs(expected), Math.Abs(actual)) * tolerance) { - Assert.Equal(x[i] + y[i], destination[i]); + throw new EqualException(expected, actual); } } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void AddTwoTensors_ThrowsForMismatchedLengths(int tensorLength) + private static unsafe float MathFMaxMagnitude(float x, float y) { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); - using BoundedMemory destination = CreateTensor(tensorLength); - - Assert.Throws(() => TensorPrimitives.Add(x, y, destination)); + float ax = MathF.Abs(x), ay = MathF.Abs(y); + return (ax > ay) || float.IsNaN(ax) || (ax == ay && *(int*)&x >= 0) ? x : y; } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void AddTwoTensors_ThrowsForTooShortDestination(int tensorLength) + private static unsafe float MathFMinMagnitude(float x, float y) { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength - 1); - - AssertExtensions.Throws("destination", () => TensorPrimitives.Add(x, y, destination)); + float ax = MathF.Abs(x), ay = MathF.Abs(y); + return (ax < ay) || float.IsNaN(ax) || (ax == ay && *(int*)&x < 0) ? x : y; } - [Theory] - [MemberData(nameof(TensorLengthsIncluding0))] - public static void AddTensorAndScalar(int tensorLength) + private static unsafe float UInt32ToSingle(uint i) => *(float*)&i; + + private static unsafe float SingleToUInt32(float f) => *(uint*)&f; + + /// Gets a variety of special values (e.g. NaN). 
+ private static IEnumerable GetSpecialValues() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - float y = NextSingle(); - using BoundedMemory destination = CreateTensor(tensorLength); + // NaN + yield return UInt32ToSingle(0xFFC0_0000); // -qNaN / float.NaN + yield return UInt32ToSingle(0xFFFF_FFFF); // -qNaN / all-bits-set + yield return UInt32ToSingle(0x7FC0_0000); // +qNaN + yield return UInt32ToSingle(0xFFA0_0000); // -sNaN + yield return UInt32ToSingle(0x7FA0_0000); // +sNaN - TensorPrimitives.Add(x, y, destination); + // +Infinity, -Infinity + yield return float.PositiveInfinity; + yield return float.NegativeInfinity; - for (int i = 0; i < tensorLength; i++) - { - Assert.Equal(x[i] + y, destination[i]); - } - } + // +Zero, -Zero + yield return +0.0f; + yield return -0.0f; - [Theory] - [MemberData(nameof(TensorLengths))] - public static void AddTensorAndScalar_ThrowsForTooShortDestination(int tensorLength) - { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - float y = NextSingle(); - using BoundedMemory destination = CreateTensor(tensorLength - 1); + // Subnormals + yield return +float.Epsilon; + yield return -float.Epsilon; + yield return UInt32ToSingle(0x007F_FFFF); + yield return UInt32ToSingle(0x807F_FFFF); - AssertExtensions.Throws("destination", () => TensorPrimitives.Add(x, y, destination)); + // Normals + yield return UInt32ToSingle(0x0080_0000); + yield return UInt32ToSingle(0x8080_0000); + yield return UInt32ToSingle(0x7F7F_FFFF); // MaxValue + yield return UInt32ToSingle(0xFF7F_FFFF); // MinValue } - [Theory] - [MemberData(nameof(TensorLengthsIncluding0))] - public static void SubtractTwoTensors(int tensorLength) + /// + /// Runs the specified action for each special value. Before the action is invoked, + /// the value is stored into a random position in , and the original + /// value is subsequently restored. + /// + private static void RunForEachSpecialValue(Action action, BoundedMemory x) { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength); + foreach (float value in GetSpecialValues()) + { + int pos = s_random.Next(x.Length); + float orig = x[pos]; + x[pos] = value; - TensorPrimitives.Subtract(x, y, destination); + action(); - for (int i = 0; i < tensorLength; i++) - { - Assert.Equal(x[i] - y[i], destination[i]); + x[pos] = orig; } } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void SubtractTwoTensors_ThrowsForMismatchedLengths(int tensorLength) + /// + /// Loads a variety of special values (e.g. NaN) into random positions in + /// and related values into the corresponding positions in . 
+ /// + private static void SetSpecialValues(Span x, Span y) { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); - using BoundedMemory destination = CreateTensor(tensorLength); + int pos; - Assert.Throws(() => TensorPrimitives.Subtract(x, y, destination)); - } + // NaNs + pos = s_random.Next(x.Length); + x[pos] = float.NaN; + y[pos] = UInt32ToSingle(0x7FC0_0000); - [Theory] - [MemberData(nameof(TensorLengths))] - public static void SubtractTwoTensors_ThrowsForTooShortDestination(int tensorLength) - { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength - 1); + // +Infinity, -Infinity + pos = s_random.Next(x.Length); + x[pos] = float.PositiveInfinity; + y[pos] = float.NegativeInfinity; - AssertExtensions.Throws("destination", () => TensorPrimitives.Subtract(x, y, destination)); + // +Zero, -Zero + pos = s_random.Next(x.Length); + x[pos] = +0.0f; + y[pos] = -0.0f; + + // +Epsilon, -Epsilon + pos = s_random.Next(x.Length); + x[pos] = +float.Epsilon; + y[pos] = -float.Epsilon; + + // Same magnitude, opposite sign + pos = s_random.Next(x.Length); + x[pos] = +5.0f; + y[pos] = -5.0f; } + #endregion + #region Abs [Theory] [MemberData(nameof(TensorLengthsIncluding0))] - public static void SubtractTensorAndScalar(int tensorLength) + public static void Abs(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - float y = NextSingle(); using BoundedMemory destination = CreateTensor(tensorLength); - TensorPrimitives.Subtract(x, y, destination); + TensorPrimitives.Abs(x, destination); - for (int i = 0; i < tensorLength; i++) + for (int i = 0; i < x.Length; i++) { - Assert.Equal(x[i] - y, destination[i]); + AssertEqualTolerance(MathF.Abs(x[i]), destination[i]); } } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void SubtractTensorAndScalar_ThrowsForTooShortDestination(int tensorLength) - { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - float y = NextSingle(); - using BoundedMemory destination = CreateTensor(tensorLength - 1); - - AssertExtensions.Throws("destination", () => TensorPrimitives.Subtract(x, y, destination)); - } - [Theory] [MemberData(nameof(TensorLengthsIncluding0))] - public static void MultiplyTwoTensors(int tensorLength) + public static void Abs_InPlace(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); - TensorPrimitives.Multiply(x, y, destination); + TensorPrimitives.Abs(x, x); - for (int i = 0; i < tensorLength; i++) + for (int i = 0; i < x.Length; i++) { - Assert.Equal(x[i] * y[i], destination[i]); + AssertEqualTolerance(MathF.Abs(xOrig[i]), x[i]); } } [Theory] [MemberData(nameof(TensorLengths))] - public static void MultiplyTwoTensors_ThrowsForMismatchedLengths(int tensorLength) + public static void Abs_ThrowsForTooShortDestination(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); - using BoundedMemory destination = CreateTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.Multiply(x, y, destination)); + AssertExtensions.Throws("destination", () => 
TensorPrimitives.Abs(x, destination)); } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void MultiplyTwoTensors_ThrowsForTooShortDestination(int tensorLength) + [Fact] + public static void Abs_ThrowsForOverlapppingInputsWithOutputs() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength - 1); - - AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(x, y, destination)); + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Abs(array.AsSpan(1, 5), array.AsSpan(0, 5))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Abs(array.AsSpan(1, 5), array.AsSpan(2, 5))); } + #endregion + #region Add [Theory] [MemberData(nameof(TensorLengthsIncluding0))] - public static void MultiplyTensorAndScalar(int tensorLength) + public static void Add_TwoTensors(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - float y = NextSingle(); + using BoundedMemory y = CreateAndFillTensor(tensorLength); using BoundedMemory destination = CreateTensor(tensorLength); - TensorPrimitives.Multiply(x, y, destination); - + TensorPrimitives.Add(x, y, destination); for (int i = 0; i < tensorLength; i++) { - Assert.Equal(x[i] * y, destination[i]); + AssertEqualTolerance(x[i] + y[i], destination[i]); } - } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void MultiplyTensorAndScalar_ThrowsForTooShortDestination(int tensorLength) - { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - float y = NextSingle(); - using BoundedMemory destination = CreateTensor(tensorLength - 1); + float[] xOrig = x.Span.ToArray(); - AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(x, y, destination)); + // Validate that the destination can be the same as an input. 
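+ // (Exact aliasing like this is allowed; the overlap validation only rejects destinations 
+ // that partially overlap an input.) 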
+ TensorPrimitives.Add(x, x, x); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(xOrig[i] + xOrig[i], x[i]); + } } [Theory] [MemberData(nameof(TensorLengthsIncluding0))] - public static void DivideTwoTensors(int tensorLength) + public static void Add_TwoTensors_InPlace(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); - TensorPrimitives.Divide(x, y, destination); + TensorPrimitives.Add(x, x, x); for (int i = 0; i < tensorLength; i++) { - Assert.Equal(x[i] / y[i], destination[i]); + AssertEqualTolerance(xOrig[i] + xOrig[i], x[i]); } } [Theory] [MemberData(nameof(TensorLengths))] - public static void DivideTwoTensors_ThrowsForMismatchedLengths(int tensorLength) + public static void Add_TwoTensors_ThrowsForMismatchedLengths(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); using BoundedMemory destination = CreateTensor(tensorLength); - Assert.Throws(() => TensorPrimitives.Divide(x, y, destination)); + Assert.Throws(() => TensorPrimitives.Add(x, y, destination)); + Assert.Throws(() => TensorPrimitives.Add(y, x, destination)); } [Theory] [MemberData(nameof(TensorLengths))] - public static void DivideTwoTensors_ThrowsForTooShortDestination(int tensorLength) + public static void Add_TwoTensors_ThrowsForTooShortDestination(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); using BoundedMemory y = CreateAndFillTensor(tensorLength); using BoundedMemory destination = CreateTensor(tensorLength - 1); - AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(x, y, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.Add(x, y, destination)); + } + + [Fact] + public static void Add_TwoTensors_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Add(array.AsSpan(1, 2), array.AsSpan(5, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Add(array.AsSpan(1, 2), array.AsSpan(5, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Add(array.AsSpan(1, 2), array.AsSpan(5, 2), array.AsSpan(4, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Add(array.AsSpan(1, 2), array.AsSpan(5, 2), array.AsSpan(6, 2))); } [Theory] [MemberData(nameof(TensorLengthsIncluding0))] - public static void DivideTensorAndScalar(int tensorLength) + public static void Add_TensorScalar(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); float y = NextSingle(); using BoundedMemory destination = CreateTensor(tensorLength); - TensorPrimitives.Divide(x, y, destination); + TensorPrimitives.Add(x, y, destination); for (int i = 0; i < tensorLength; i++) { - Assert.Equal(x[i] / y, destination[i]); + AssertEqualTolerance(x[i] + y, destination[i]); } } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void DivideTensorAndScalar_ThrowsForTooShortDestination(int tensorLength) - { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - float y = NextSingle(); - using BoundedMemory destination = CreateTensor(tensorLength - 1); - - AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(x, y, destination)); - } - [Theory] 
[MemberData(nameof(TensorLengthsIncluding0))] - public static void NegateTensor(int tensorLength) + public static void Add_TensorScalar_InPlace(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + float y = NextSingle(); - TensorPrimitives.Negate(x, destination); + TensorPrimitives.Add(x, y, x); for (int i = 0; i < tensorLength; i++) { - Assert.Equal(-x[i], destination[i]); + AssertEqualTolerance(xOrig[i] + y, x[i]); } } [Theory] [MemberData(nameof(TensorLengths))] - public static void NegateTensor_ThrowsForTooShortDestination(int tensorLength) + public static void Add_TensorScalar_ThrowsForTooShortDestination(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); + float y = NextSingle(); using BoundedMemory destination = CreateTensor(tensorLength - 1); - AssertExtensions.Throws("destination", () => TensorPrimitives.Negate(x, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.Add(x, y, destination)); + } + + [Fact] + public static void Add_TensorScalar_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Add(array.AsSpan(1, 2), 42, array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Add(array.AsSpan(1, 2), 42, array.AsSpan(2, 2))); } + #endregion + #region AddMultiply [Theory] [MemberData(nameof(TensorLengthsIncluding0))] - public static void AddTwoTensorsAndMultiplyWithThirdTensor(int tensorLength) + public static void AddMultiply_ThreeTensors(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); using BoundedMemory y = CreateAndFillTensor(tensorLength); @@ -345,37 +356,42 @@ public static void AddTwoTensorsAndMultiplyWithThirdTensor(int tensorLength) for (int i = 0; i < tensorLength; i++) { - Assert.Equal((x[i] + y[i]) * multiplier[i], destination[i]); + AssertEqualTolerance((x[i] + y[i]) * multiplier[i], destination[i]); } } [Theory] - [MemberData(nameof(TensorLengths))] - public static void AddTwoTensorsAndMultiplyWithThirdTensor_ThrowsForMismatchedLengths_x_y(int tensorLength) + [MemberData(nameof(TensorLengthsIncluding0))] + public static void AddMultiply_ThreeTensors_InPlace(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); - using BoundedMemory multiplier = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); - Assert.Throws(() => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); + TensorPrimitives.AddMultiply(x, x, x, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance((xOrig[i] + xOrig[i]) * xOrig[i], x[i]); + } } [Theory] [MemberData(nameof(TensorLengths))] - public static void AddTwoTensorsAndMultiplyWithThirdTensor_ThrowsForMismatchedLengths_x_multiplier(int tensorLength) + public static void AddMultiply_ThreeTensors_ThrowsForMismatchedLengths(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); using BoundedMemory y = CreateAndFillTensor(tensorLength); - using BoundedMemory multiplier = CreateAndFillTensor(tensorLength - 1); + using BoundedMemory z = CreateAndFillTensor(tensorLength - 1); using BoundedMemory destination = CreateTensor(tensorLength); - Assert.Throws(() => TensorPrimitives.AddMultiply(x, y, multiplier, 
destination)); + Assert.Throws(() => TensorPrimitives.AddMultiply(x, y, z, destination)); + Assert.Throws(() => TensorPrimitives.AddMultiply(x, z, y, destination)); + Assert.Throws(() => TensorPrimitives.AddMultiply(z, x, y, destination)); } [Theory] [MemberData(nameof(TensorLengths))] - public static void AddTwoTensorsAndMultiplyWithThirdTensor_ThrowsForTooShortDestination(int tensorLength) + public static void AddMultiply_ThreeTensors_ThrowsForTooShortDestination(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); using BoundedMemory y = CreateAndFillTensor(tensorLength); @@ -385,9 +401,21 @@ public static void AddTwoTensorsAndMultiplyWithThirdTensor_ThrowsForTooShortDest AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); } + [Fact] + public static void AddMultiply_ThreeTensors_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(5, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(6, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(8, 2))); + } + [Theory] [MemberData(nameof(TensorLengthsIncluding0))] - public static void AddTwoTensorsAndMultiplyWithScalar(int tensorLength) + public static void AddMultiply_TensorTensorScalar(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); using BoundedMemory y = CreateAndFillTensor(tensorLength); @@ -398,13 +426,29 @@ public static void AddTwoTensorsAndMultiplyWithScalar(int tensorLength) for (int i = 0; i < tensorLength; i++) { - Assert.Equal((x[i] + y[i]) * multiplier, destination[i]); + AssertEqualTolerance((x[i] + y[i]) * multiplier, destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void AddMultiply_TensorTensorScalar_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + float multiplier = NextSingle(); + + TensorPrimitives.AddMultiply(x, x, multiplier, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance((xOrig[i] + xOrig[i]) * multiplier, x[i]); } } [Theory] [MemberData(nameof(TensorLengths))] - public static void AddTwoTensorsAndMultiplyWithScalar_ThrowsForMismatchedLengths_x_y(int tensorLength) + public static void AddMultiply_TensorTensorScalar_ThrowsForMismatchedLengths_x_y(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); @@ -412,11 +456,12 @@ public static void AddTwoTensorsAndMultiplyWithScalar_ThrowsForMismatchedLengths using BoundedMemory destination = CreateTensor(tensorLength); Assert.Throws(() => TensorPrimitives.AddMultiply(x, y, multiplier, 
destination)); + Assert.Throws(() => TensorPrimitives.AddMultiply(y, x, multiplier, destination)); } [Theory] [MemberData(nameof(TensorLengths))] - public static void AddTwoTensorsAndMultiplyWithScalar_ThrowsForTooShortDestination(int tensorLength) + public static void AddMultiply_TensorTensorScalar_ThrowsForTooShortDestination(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); using BoundedMemory y = CreateAndFillTensor(tensorLength); @@ -426,9 +471,19 @@ public static void AddTwoTensorsAndMultiplyWithScalar_ThrowsForTooShortDestinati AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); } + [Fact] + public static void AddMultiply_TensorTensorScalar_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), 42, array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), 42, array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), 42, array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), 42, array.AsSpan(5, 2))); + } + [Theory] [MemberData(nameof(TensorLengthsIncluding0))] - public static void AddTensorAndScalarAndMultiplyWithTensor(int tensorLength) + public static void AddMultiply_TensorScalarTensor(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); float y = NextSingle(); @@ -439,25 +494,42 @@ public static void AddTensorAndScalarAndMultiplyWithTensor(int tensorLength) for (int i = 0; i < tensorLength; i++) { - Assert.Equal((x[i] + y) * multiplier[i], destination[i]); + AssertEqualTolerance((x[i] + y) * multiplier[i], destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void AddMultiply_TensorScalarTensor_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + float y = NextSingle(); + + TensorPrimitives.AddMultiply(x, y, x, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance((xOrig[i] + y) * xOrig[i], x[i]); } } [Theory] [MemberData(nameof(TensorLengths))] - public static void AddTensorAndScalarAndMultiplyWithTensor_ThrowsForMismatchedLengths_x_z(int tensorLength) + public static void AddMultiply_TensorScalarTensor_ThrowsForMismatchedLengths_x_z(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); float y = NextSingle(); - using BoundedMemory multiplier = CreateAndFillTensor(tensorLength - 1); + using BoundedMemory z = CreateAndFillTensor(tensorLength - 1); using BoundedMemory destination = CreateTensor(tensorLength); - Assert.Throws(() => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); + Assert.Throws(() => TensorPrimitives.AddMultiply(x, y, z, destination)); + Assert.Throws(() => TensorPrimitives.AddMultiply(z, y, x, destination)); } [Theory] [MemberData(nameof(TensorLengths))] - public static void AddTensorAndScalarAndMultiplyWithTensor_ThrowsForTooShortDestination(int tensorLength) + public static void AddMultiply_TensorScalarTensor_ThrowsForTooShortDestination(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); float y = NextSingle(); @@ -467,355 +539,315 @@ public static void 
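// For reference, the three AddMultiply shapes exercised in this region compute,
// element by element:
//   destination[i] = (x[i] + y[i]) * z[i]          // tensor, tensor, tensor
//   destination[i] = (x[i] + y[i]) * multiplier    // tensor, tensor, scalar
//   destination[i] = (x[i] + y)    * z[i]          // tensor, scalar, tensor
// e.g. for the tensor/tensor/scalar shape (illustrative values):
float[] x = { 1f, 2f }, y = { 3f, 4f }, destination = new float[2];
TensorPrimitives.AddMultiply(x, y, 2f, destination);   // { (1+3)*2, (2+4)*2 } = { 8f, 12f }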
AddTensorAndScalarAndMultiplyWithTensor_ThrowsForTooShortDest AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); } + [Fact] + public static void AddMultiply_TensorScalarTensor_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), 42, array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), 42, array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), 42, array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), 42, array.AsSpan(4, 2), array.AsSpan(5, 2))); + } + #endregion + + #region Cosh [Theory] [MemberData(nameof(TensorLengthsIncluding0))] - public static void MultiplyTwoTensorsAndAddWithThirdTensor(int tensorLength) + public static void Cosh(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - using BoundedMemory addend = CreateAndFillTensor(tensorLength); using BoundedMemory destination = CreateTensor(tensorLength); - TensorPrimitives.MultiplyAdd(x, y, addend, destination); + TensorPrimitives.Cosh(x, destination); for (int i = 0; i < tensorLength; i++) { - Assert.Equal((x[i] * y[i]) + addend[i], destination[i]); + AssertEqualTolerance(MathF.Cosh(x[i]), destination[i]); } } [Theory] - [MemberData(nameof(TensorLengths))] - public static void MultiplyTwoTensorsAndAddWithThirdTensor_ThrowsForMismatchedLengths_x_y(int tensorLength) + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Cosh_InPlace(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); - using BoundedMemory addend = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength); - - Assert.Throws(() => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); - } + float[] xOrig = x.Span.ToArray(); - [Theory] - [MemberData(nameof(TensorLengths))] - public static void MultiplyTwoTensorsAndAddWithThirdTensor_ThrowsForMismatchedLengths_x_multiplier(int tensorLength) - { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - using BoundedMemory addend = CreateAndFillTensor(tensorLength - 1); - using BoundedMemory destination = CreateTensor(tensorLength); + TensorPrimitives.Cosh(x, x); - Assert.Throws(() => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Cosh(xOrig[i]), x[i]); + } } [Theory] [MemberData(nameof(TensorLengths))] - public static void MultiplyTwoTensorsAndAddWithThirdTensor_ThrowsForTooShortDestination(int tensorLength) + public static void Cosh_SpecialValues(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - using BoundedMemory addend = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength - 1); + using BoundedMemory destination = CreateTensor(tensorLength); - AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); + RunForEachSpecialValue(() => + { + 
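                // RunForEachSpecialValue is a helper defined elsewhere in this test file;
                // it presumably plants special values (float.NaN, float.PositiveInfinity,
                // float.NegativeInfinity, +0f, -0f) into positions of x and re-runs the
                // supplied assertion each time, so the vectorized code paths are checked
                // against MathF on non-finite inputs as well. (An assumption from the
                // call shape here, not a quote of the helper.)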
TensorPrimitives.Cosh(x, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Cosh(x[i]), destination[i]); + } + }, x); } [Theory] - [MemberData(nameof(TensorLengthsIncluding0))] - public static void MultiplyTwoTensorsAndAddWithScalar(int tensorLength) + [MemberData(nameof(VectorLengthAndIteratedRange), new object[] { -100f, 100f, 3f })] + public static void Cosh_ValueRange(int vectorLength, float element) { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - float addend = NextSingle(); - using BoundedMemory destination = CreateTensor(tensorLength); + float[] x = new float[vectorLength]; + float[] dest = new float[vectorLength]; - TensorPrimitives.MultiplyAdd(x, y, addend, destination); + x.AsSpan().Fill(element); + TensorPrimitives.Cosh(x, dest); - for (int i = 0; i < tensorLength; i++) + float expected = MathF.Cosh(element); + foreach (float actual in dest) { - Assert.Equal((x[i] * y[i]) + addend, destination[i]); + AssertEqualTolerance(expected, actual); } } [Theory] [MemberData(nameof(TensorLengths))] - public static void MultiplyTwoTensorsAndAddWithScalar_ThrowsForTooShortDestination(int tensorLength) + public static void Cosh_ThrowsForTooShortDestination(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - float addend = NextSingle(); using BoundedMemory destination = CreateTensor(tensorLength - 1); - AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.Cosh(x, destination)); } + [Fact] + public static void Cosh_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Cosh(array.AsSpan(1, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Cosh(array.AsSpan(1, 2), array.AsSpan(2, 2))); + } + #endregion + + #region CosineSimilarity [Theory] - [MemberData(nameof(TensorLengthsIncluding0))] - public static void MultiplyTensorAndScalarAndAddWithTensor(int tensorLength) + [MemberData(nameof(TensorLengths))] + public static void CosineSimilarity_ThrowsForMismatchedLengths(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - float y = NextSingle(); - using BoundedMemory addend = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength); - - TensorPrimitives.MultiplyAdd(x, y, addend, destination); + using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); - for (int i = 0; i < tensorLength; i++) - { - Assert.Equal((x[i] * y) + addend[i], destination[i]); - } + Assert.Throws(() => TensorPrimitives.CosineSimilarity(x, y)); + Assert.Throws(() => TensorPrimitives.CosineSimilarity(y, x)); } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void MultiplyTensorAndScalarAndAddWithTensor_ThrowsForTooShortDestination(int tensorLength) + [Fact] + public static void CosineSimilarity_ThrowsForEmpty() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - float y = NextSingle(); - using BoundedMemory addend = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength - 1); - - AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); + Assert.Throws(() => 
TensorPrimitives.CosineSimilarity(ReadOnlySpan<float>.Empty, ReadOnlySpan<float>.Empty));
+        Assert.Throws<ArgumentException>(() => TensorPrimitives.CosineSimilarity(ReadOnlySpan<float>.Empty, CreateTensor(1)));
+        Assert.Throws<ArgumentException>(() => TensorPrimitives.CosineSimilarity(CreateTensor(1), ReadOnlySpan<float>.Empty));
     }
 
     [Theory]
-    [MemberData(nameof(TensorLengthsIncluding0))]
-    public static void ExpTensor(int tensorLength)
+    [InlineData(new float[] { 3, 2, 0, 5 }, new float[] { 1, 0, 0, 0 }, 0.48666f)]
+    [InlineData(new float[] { 1, 1, 1, 1, 1, 0 }, new float[] { 1, 1, 1, 1, 0, 1 }, 0.80f)]
+    public static void CosineSimilarity_KnownValues(float[] x, float[] y, float expectedResult)
     {
-        using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-        using BoundedMemory<float> destination = CreateTensor(tensorLength);
-
-        TensorPrimitives.Exp(x, destination);
-
-        for (int i = 0; i < tensorLength; i++)
-        {
-            Assert.Equal(MathF.Exp(x[i]), destination[i]);
-        }
+        AssertEqualTolerance(expectedResult, TensorPrimitives.CosineSimilarity(x, y));
     }
 
     [Theory]
     [MemberData(nameof(TensorLengths))]
-    public static void ExpTensor_ThrowsForTooShortDestination(int tensorLength)
-    {
-        using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-        using BoundedMemory<float> destination = CreateTensor(tensorLength - 1);
-
-        AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Exp(x, destination));
-    }
-
-    [Theory]
-    [MemberData(nameof(TensorLengthsIncluding0))]
-    public static void LogTensor(int tensorLength)
+    public static void CosineSimilarity(int tensorLength)
     {
         using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-        using BoundedMemory<float> destination = CreateTensor(tensorLength);
-
-        TensorPrimitives.Log(x, destination);
+        using BoundedMemory<float> y = CreateAndFillTensor(tensorLength);
 
-        for (int i = 0; i < tensorLength; i++)
+        float dot = 0f, squareX = 0f, squareY = 0f;
+        for (int i = 0; i < x.Length; i++)
         {
-            Assert.Equal(MathF.Log(x[i]), destination[i]);
+            dot += x[i] * y[i];
+            squareX += x[i] * x[i];
+            squareY += y[i] * y[i];
         }
-    }
-
-    [Theory]
-    [MemberData(nameof(TensorLengths))]
-    public static void LogTensor_ThrowsForTooShortDestination(int tensorLength)
-    {
-        using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-        using BoundedMemory<float> destination = CreateTensor(tensorLength - 1);
 
-        AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.Log(x, destination));
+        AssertEqualTolerance(dot / (MathF.Sqrt(squareX) * MathF.Sqrt(squareY)), TensorPrimitives.CosineSimilarity(x, y));
     }
+    #endregion
 
-    [Theory]
-    [MemberData(nameof(TensorLengthsIncluding0))]
-    public static void Log2(int tensorLength)
+    #region Distance
+    [Fact]
+    public static void Distance_ThrowsForEmpty()
     {
-        using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-        using BoundedMemory<float> destination = CreateTensor(tensorLength);
-
-        TensorPrimitives.Log2(x, destination);
-
-        for (int i = 0; i < tensorLength; i++)
-        {
-            Assert.Equal(MathF.Log(x[i], 2), destination[i], Tolerance);
-        }
+        Assert.Throws<ArgumentException>(() => TensorPrimitives.Distance(ReadOnlySpan<float>.Empty, ReadOnlySpan<float>.Empty));
+        Assert.Throws<ArgumentException>(() => TensorPrimitives.Distance(ReadOnlySpan<float>.Empty, CreateTensor(1)));
+        Assert.Throws<ArgumentException>(() => TensorPrimitives.Distance(CreateTensor(1), ReadOnlySpan<float>.Empty));
     }
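// The CosineSimilarity tests pin the result to the textbook formula
//   cos(x, y) = (Σ x[i]*y[i]) / (sqrt(Σ x[i]^2) * sqrt(Σ y[i]^2)).
// Checking the first KnownValues row by hand: dot = 3*1 + 2*0 + 0*0 + 5*0 = 3,
// |x| = sqrt(9 + 4 + 0 + 25) = sqrt(38) ≈ 6.16441, |y| = 1, and 3 / 6.16441 ≈ 0.48666,
// matching the expected 0.48666f (the second row gives 4 / (sqrt(5)*sqrt(5)) = 0.80).
 
     [Theory]
     [MemberData(nameof(TensorLengths))]
-    public static void Log2_ThrowsForTooShortDestination(int tensorLength)
+    public static void Distance_ThrowsForMismatchedLengths(int tensorLength)
     {
         using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
-        using BoundedMemory<float> destination = CreateTensor(tensorLength - 1);
+        using BoundedMemory<float> y =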
CreateAndFillTensor(tensorLength - 1); - AssertExtensions.Throws("destination", () => TensorPrimitives.Log2(x, destination)); + Assert.Throws(() => TensorPrimitives.Distance(x, y)); + Assert.Throws(() => TensorPrimitives.Distance(y, x)); } [Theory] - [MemberData(nameof(TensorLengthsIncluding0))] - public static void CoshTensor(int tensorLength) + [InlineData(new float[] { 3, 2 }, new float[] { 4, 1 }, 1.4142f)] + [InlineData(new float[] { 0, 4 }, new float[] { 6, 2 }, 6.3245f)] + [InlineData(new float[] { 1, 2, 3 }, new float[] { 4, 5, 6 }, 5.19615f)] + [InlineData(new float[] { 5, 1, 6, 10 }, new float[] { 7, 2, 8, 4 }, 6.7082f)] + public static void Distance_KnownValues(float[] x, float[] y, float expectedResult) { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength); - - TensorPrimitives.Cosh(x, destination); - - for (int i = 0; i < tensorLength; i++) - { - Assert.Equal(MathF.Cosh(x[i]), destination[i]); - } + AssertEqualTolerance(expectedResult, TensorPrimitives.Distance(x, y)); } [Theory] [MemberData(nameof(TensorLengths))] - public static void CoshTensor_ThrowsForTooShortDestination(int tensorLength) + public static void Distance(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength - 1); + using BoundedMemory y = CreateAndFillTensor(tensorLength); - AssertExtensions.Throws("destination", () => TensorPrimitives.Cosh(x, destination)); + float distance = 0f; + for (int i = 0; i < x.Length; i++) + { + distance += (x[i] - y[i]) * (x[i] - y[i]); + } + + AssertEqualTolerance(MathF.Sqrt(distance), TensorPrimitives.Distance(x, y)); } + #endregion + #region Divide [Theory] [MemberData(nameof(TensorLengthsIncluding0))] - public static void SinhTensor(int tensorLength) + public static void Divide_TwoTensors(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); using BoundedMemory destination = CreateTensor(tensorLength); - TensorPrimitives.Sinh(x, destination); + TensorPrimitives.Divide(x, y, destination); for (int i = 0; i < tensorLength; i++) { - Assert.Equal(MathF.Sinh(x[i]), destination[i]); + AssertEqualTolerance(x[i] / y[i], destination[i]); } } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void SinhTensor_ThrowsForTooShortDestination(int tensorLength) - { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength - 1); - - AssertExtensions.Throws("destination", () => TensorPrimitives.Sinh(x, destination)); - } - [Theory] [MemberData(nameof(TensorLengthsIncluding0))] - public static void TanhTensor(int tensorLength) + public static void Divide_TwoTensors_InPlace(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); - TensorPrimitives.Tanh(x, destination); + TensorPrimitives.Divide(x, x, x); for (int i = 0; i < tensorLength; i++) { - Assert.Equal(MathF.Tanh(x[i]), destination[i]); + AssertEqualTolerance(xOrig[i] / xOrig[i], x[i]); } } [Theory] [MemberData(nameof(TensorLengths))] - public static void TanhTensor_ThrowsForTooShortDestination(int tensorLength) + public static void Divide_TwoTensors_ThrowsForMismatchedLengths(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = 
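// Distance is the Euclidean norm of the element-wise difference:
//   Distance(x, y) = sqrt(Σ (x[i] - y[i])^2).
// Checking the first KnownValues row: x = {3, 2} and y = {4, 1} give differences
// {-1, 1}, a sum of squares of 2, and sqrt(2) ≈ 1.41421, matching the expected 1.4142f.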
CreateTensor(tensorLength - 1); + using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); + using BoundedMemory destination = CreateTensor(tensorLength); - AssertExtensions.Throws("destination", () => TensorPrimitives.Tanh(x, destination)); + Assert.Throws(() => TensorPrimitives.Divide(x, y, destination)); + Assert.Throws(() => TensorPrimitives.Divide(y, x, destination)); } [Theory] [MemberData(nameof(TensorLengths))] - public static void CosineSimilarity_ThrowsForMismatchedLengths_x_y(int tensorLength) + public static void Divide_TwoTensors_ThrowsForTooShortDestination(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.CosineSimilarity(x, y)); + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(x, y, destination)); } [Fact] - public static void CosineSimilarity_ThrowsForEmpty_x_y() - { - Assert.Throws(() => TensorPrimitives.CosineSimilarity(ReadOnlySpan.Empty, ReadOnlySpan.Empty)); - Assert.Throws(() => TensorPrimitives.CosineSimilarity(ReadOnlySpan.Empty, CreateTensor(1))); - Assert.Throws(() => TensorPrimitives.CosineSimilarity(CreateTensor(1), ReadOnlySpan.Empty)); - } - - [Theory] - [InlineData(new float[] { 3, 2, 0, 5 }, new float[] { 1, 0, 0, 0 }, 0.48666f)] - [InlineData(new float[] { 1, 1, 1, 1, 1, 0 }, new float[] { 1, 1, 1, 1, 0, 1 }, 0.80f)] - public static void CosineSimilarity_KnownValues(float[] x, float[] y, float expectedResult) + public static void Divide_TwoTensors_ThrowsForOverlapppingInputsWithOutputs() { - Assert.Equal(expectedResult, TensorPrimitives.CosineSimilarity(x, y), Tolerance); + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2))); } [Theory] - [MemberData(nameof(TensorLengths))] - public static void CosineSimilarity(int tensorLength) + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Divide_TensorScalar(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); + float y = NextSingle(); + using BoundedMemory destination = CreateTensor(tensorLength); - float dot = 0f, squareX = 0f, squareY = 0f; - for (int i = 0; i < x.Length; i++) + TensorPrimitives.Divide(x, y, destination); + + for (int i = 0; i < tensorLength; i++) { - dot += x[i] * y[i]; - squareX += x[i] * x[i]; - squareY += y[i] * y[i]; + AssertEqualTolerance(x[i] / y, destination[i]); } - - Assert.Equal(dot / (Math.Sqrt(squareX) * Math.Sqrt(squareY)), TensorPrimitives.CosineSimilarity(x, y), Tolerance); - } - - [Fact] - public static void Distance_ThrowsForEmpty_x_y() - { - Assert.Throws(() => TensorPrimitives.Distance(ReadOnlySpan.Empty, ReadOnlySpan.Empty)); - Assert.Throws(() => TensorPrimitives.Distance(ReadOnlySpan.Empty, CreateTensor(1))); - Assert.Throws(() => TensorPrimitives.Distance(CreateTensor(1), 
ReadOnlySpan.Empty)); } [Theory] - [MemberData(nameof(TensorLengths))] - public static void Distance_ThrowsForMismatchedLengths_x_y(int tensorLength) + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Divide_TensorScalar_InPlace(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); + float[] xOrig = x.Span.ToArray(); + float y = NextSingle(); - Assert.Throws(() => TensorPrimitives.Distance(x, y)); - } + TensorPrimitives.Divide(x, y, x); - [Theory] - [InlineData(new float[] { 3, 2 }, new float[] { 4, 1 }, 1.4142f)] - [InlineData(new float[] { 0, 4 }, new float[] { 6, 2 }, 6.3245f)] - [InlineData(new float[] { 1, 2, 3 }, new float[] { 4, 5, 6 }, 5.1961f)] - [InlineData(new float[] { 5, 1, 6, 10 }, new float[] { 7, 2, 8, 4 }, 6.7082f)] - public static void Distance_KnownValues(float[] x, float[] y, float expectedResult) - { - Assert.Equal(expectedResult, TensorPrimitives.Distance(x, y), Tolerance); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(xOrig[i] / y, x[i]); + } } [Theory] [MemberData(nameof(TensorLengths))] - public static void Distance(int tensorLength) + public static void Divide_TensorScalar_ThrowsForTooShortDestination(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); + float y = NextSingle(); + using BoundedMemory destination = CreateTensor(tensorLength - 1); - float distance = 0f; - for (int i = 0; i < x.Length; i++) - { - distance += (x[i] - y[i]) * (x[i] - y[i]); - } + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(x, y, destination)); + } - Assert.Equal(Math.Sqrt(distance), TensorPrimitives.Distance(x, y), Tolerance); + [Fact] + public static void Divide_TensorScalar_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2))); } + #endregion + #region Dot [Theory] [MemberData(nameof(TensorLengths))] public static void Dot_ThrowsForMismatchedLengths_x_y(int tensorLength) @@ -824,6 +856,7 @@ public static void Dot_ThrowsForMismatchedLengths_x_y(int tensorLength) using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); Assert.Throws(() => TensorPrimitives.Dot(x, y)); + Assert.Throws(() => TensorPrimitives.Dot(y, x)); } [Theory] @@ -833,7 +866,7 @@ public static void Dot_ThrowsForMismatchedLengths_x_y(int tensorLength) [InlineData(new float[] { }, new float[] { }, 0)] public static void Dot_KnownValues(float[] x, float[] y, float expectedResult) { - Assert.Equal(expectedResult, TensorPrimitives.Dot(x, y), Tolerance); + AssertEqualTolerance(expectedResult, TensorPrimitives.Dot(x, y)); } [Theory] @@ -849,171 +882,166 @@ public static void Dot(int tensorLength) dot += x[i] * y[i]; } - Assert.Equal(dot, TensorPrimitives.Dot(x, y), Tolerance); + AssertEqualTolerance(dot, TensorPrimitives.Dot(x, y)); } + #endregion + #region Exp [Theory] - [InlineData(new float[] { 1, 2, 3 }, 3.7416575f)] - 
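// Dot is the plain inner product, Σ x[i]*y[i], and the KnownValues data above pins the
// dot product of two empty tensors to 0. A quick illustrative call:
float[] x = { 1f, 2f, 3f }, y = { 4f, 5f, 6f };
float dot = TensorPrimitives.Dot(x, y);   // 1*4 + 2*5 + 3*6 = 32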
[InlineData(new float[] { 3, 4 }, 5)] - [InlineData(new float[] { 3 }, 3)] - [InlineData(new float[] { 3, 4, 1, 2 }, 5.477226)] - [InlineData(new float[] { }, 0f)] - public static void Norm_KnownValues(float[] x, float expectedResult) + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Exp(int tensorLength) { - Assert.Equal(expectedResult, TensorPrimitives.Norm(x), Tolerance); + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.Exp(x, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Exp(x[i]), destination[i]); + } } [Theory] [MemberData(nameof(TensorLengthsIncluding0))] - public static void Norm(int tensorLength) + public static void Exp_InPlace(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); - float sumOfSquares = 0f; - for (int i = 0; i < x.Length; i++) + TensorPrimitives.Exp(x, x); + + for (int i = 0; i < tensorLength; i++) { - sumOfSquares += x[i] * x[i]; + AssertEqualTolerance(MathF.Exp(xOrig[i]), x[i]); } - - Assert.Equal(Math.Sqrt(sumOfSquares), TensorPrimitives.Norm(x), Tolerance); } [Theory] [MemberData(nameof(TensorLengths))] - public static void SoftMax_ThrowsForTooShortDestination(int tensorLength) + public static void Exp_SpecialValues(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength - 1); + using BoundedMemory destination = CreateTensor(tensorLength); - AssertExtensions.Throws("destination", () => TensorPrimitives.SoftMax(x, destination)); + RunForEachSpecialValue(() => + { + TensorPrimitives.Exp(x, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Exp(x[i]), destination[i]); + } + }, x); } [Theory] - [InlineData(new float[] { 3, 1, .2f }, new float[] { 0.8360188f, 0.11314284f, 0.05083836f })] - [InlineData(new float[] { 3, 4, 1 }, new float[] { 0.2594f, 0.705384f, 0.0351f })] - [InlineData(new float[] { 5, 3 }, new float[] { 0.8807f, 0.1192f })] - [InlineData(new float[] { 4, 2, 1, 9 }, new float[] { 0.0066f, 9.04658e-4f, 3.32805e-4f, 0.9920f})] - public static void SoftMax(float[] x, float[] expectedResult) + [MemberData(nameof(TensorLengths))] + public static void Exp_ThrowsForTooShortDestination(int tensorLength) { - using BoundedMemory dest = CreateTensor(x.Length); - TensorPrimitives.SoftMax(x, dest); + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); - for (int i = 0; i < x.Length; i++) - { - Assert.Equal(expectedResult[i], dest[i], Tolerance); - } + AssertExtensions.Throws("destination", () => TensorPrimitives.Exp(x, destination)); } [Fact] - public static void SoftMax_DestinationLongerThanSource() + public static void Exp_ThrowsForOverlapppingInputsWithOutputs() { - float[] x = [3, 1, .2f]; - float[] expectedResult = [0.8360188f, 0.11314284f, 0.05083836f]; - using BoundedMemory dest = CreateTensor(x.Length + 1); - TensorPrimitives.SoftMax(x, dest); - - for (int i = 0; i < x.Length; i++) - { - Assert.Equal(expectedResult[i], dest[i], Tolerance); - } + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Exp(array.AsSpan(1, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Exp(array.AsSpan(1, 2), array.AsSpan(2, 2))); } + #endregion + #region IndexOfMax [Fact] - public 
static void SoftMax_ThrowsForEmptyInput() + public static void IndexOfMax_ReturnsNegative1OnEmpty() { - AssertExtensions.Throws(() => TensorPrimitives.SoftMax(ReadOnlySpan.Empty, CreateTensor(1))); + Assert.Equal(-1, TensorPrimitives.IndexOfMax(ReadOnlySpan.Empty)); } [Theory] [MemberData(nameof(TensorLengths))] - public static void Sigmoid_ThrowsForTooShortDestination(int tensorLength) + public static void IndexOfMax(int tensorLength) { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength - 1); - - AssertExtensions.Throws("destination", () => TensorPrimitives.Sigmoid(x, destination)); + foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + x[expected] = Enumerable.Max(MemoryMarshal.ToEnumerable(x.Memory)) + 1; + Assert.Equal(expected, TensorPrimitives.IndexOfMax(x)); + } } [Theory] - [InlineData(new float[] { -5, -4.5f, -4 }, new float[] { 0.0066f, 0.0109f, 0.0179f })] - [InlineData(new float[] { 4.5f, 5 }, new float[] { 0.9890f, 0.9933f })] - [InlineData(new float[] { 0, -3, 3, .5f }, new float[] { 0.5f, 0.0474f, 0.9525f, 0.6224f })] - public static void Sigmoid(float[] x, float[] expectedResult) - { - using BoundedMemory dest = CreateTensor(x.Length); - TensorPrimitives.Sigmoid(x, dest); - - for (int i = 0; i < x.Length; i++) - { - Assert.Equal(expectedResult[i], dest[i], Tolerance); - } - } - - [Fact] - public static void Sigmoid_DestinationLongerThanSource() + [MemberData(nameof(TensorLengths))] + public static void IndexOfMax_FirstNaNReturned(int tensorLength) { - float[] x = [-5, -4.5f, -4]; - float[] expectedResult = [0.0066f, 0.0109f, 0.0179f]; - using BoundedMemory dest = CreateTensor(x.Length + 1); - - TensorPrimitives.Sigmoid(x, dest); - - float originalLast = dest[dest.Length - 1]; - for (int i = 0; i < x.Length; i++) + foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) { - Assert.Equal(expectedResult[i], dest[i], Tolerance); + using BoundedMemory x = CreateAndFillTensor(tensorLength); + x[expected] = float.NaN; + x[tensorLength - 1] = float.NaN; + Assert.Equal(expected, TensorPrimitives.IndexOfMax(x)); } - Assert.Equal(originalLast, dest[dest.Length - 1]); } [Fact] - public static void Sigmoid_ThrowsForEmptyInput() + public static void IndexOfMax_Negative0LesserThanPositive0() { - AssertExtensions.Throws(() => TensorPrimitives.Sigmoid(ReadOnlySpan.Empty, CreateTensor(1))); + Assert.Equal(1, TensorPrimitives.IndexOfMax([-0f, +0f])); + Assert.Equal(0, TensorPrimitives.IndexOfMax([-0f, -0f, -0f, -0f])); + Assert.Equal(4, TensorPrimitives.IndexOfMax([-0f, -0f, -0f, -0f, +0f, +0f, +0f])); + Assert.Equal(0, TensorPrimitives.IndexOfMax([+0f, -0f])); + Assert.Equal(1, TensorPrimitives.IndexOfMax([-1, -0f])); + Assert.Equal(2, TensorPrimitives.IndexOfMax([-1, -0f, 1])); } + #endregion + #region IndexOfMaxMagnitude [Fact] - public static void IndexOfMax_ReturnsNegative1OnEmpty() + public static void IndexOfMaxMagnitude_ReturnsNegative1OnEmpty() { - Assert.Equal(-1, TensorPrimitives.IndexOfMax(ReadOnlySpan.Empty)); + Assert.Equal(-1, TensorPrimitives.IndexOfMaxMagnitude(ReadOnlySpan.Empty)); } [Theory] [MemberData(nameof(TensorLengths))] - public static void IndexOfMax(int tensorLength) + public static void IndexOfMaxMagnitude(int tensorLength) { foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - x[expected] = 
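// The IndexOfMax tests above encode two ordering rules: NaN compares greater than
// everything else (and the first NaN wins), and -0f compares less than +0f, e.g.:
//   TensorPrimitives.IndexOfMax(new float[] { 2f, float.NaN, 5f, float.NaN })  // 1: first NaN
//   TensorPrimitives.IndexOfMax(new float[] { -0f, +0f })                      // 1: +0f > -0f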
Enumerable.Max(MemoryMarshal.ToEnumerable(x.Memory)) + 1; - Assert.Equal(expected, TensorPrimitives.IndexOfMax(x)); + x[expected] = Enumerable.Max(MemoryMarshal.ToEnumerable(x.Memory), Math.Abs) + 1; + Assert.Equal(expected, TensorPrimitives.IndexOfMaxMagnitude(x)); } } [Theory] [MemberData(nameof(TensorLengths))] - public static void IndexOfMax_FirstNaNReturned(int tensorLength) + public static void IndexOfMaxMagnitude_FirstNaNReturned(int tensorLength) { foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) { using BoundedMemory x = CreateAndFillTensor(tensorLength); x[expected] = float.NaN; x[tensorLength - 1] = float.NaN; - Assert.Equal(expected, TensorPrimitives.IndexOfMax(x)); + Assert.Equal(expected, TensorPrimitives.IndexOfMaxMagnitude(x)); } } [Fact] - public static void IndexOfMax_Negative0LesserThanPositive0() + public static void IndexOfMaxMagnitude_Negative0LesserThanPositive0() { - Assert.Equal(1, TensorPrimitives.IndexOfMax([-0f, +0f])); - Assert.Equal(0, TensorPrimitives.IndexOfMax([-0f, -0f, -0f, -0f])); - Assert.Equal(4, TensorPrimitives.IndexOfMax([-0f, -0f, -0f, -0f, +0f, +0f, +0f])); - Assert.Equal(0, TensorPrimitives.IndexOfMax([+0f, -0f])); - Assert.Equal(1, TensorPrimitives.IndexOfMax([-1, -0f])); - Assert.Equal(2, TensorPrimitives.IndexOfMax([-1, -0f, 1])); + Assert.Equal(0, TensorPrimitives.IndexOfMaxMagnitude([-0f, -0f, -0f, -0f])); + Assert.Equal(1, TensorPrimitives.IndexOfMaxMagnitude([-0f, +0f])); + Assert.Equal(1, TensorPrimitives.IndexOfMaxMagnitude([-0f, +0f, +0f, +0f])); + Assert.Equal(0, TensorPrimitives.IndexOfMaxMagnitude([+0f, -0f])); + Assert.Equal(0, TensorPrimitives.IndexOfMaxMagnitude([-1, -0f])); + Assert.Equal(2, TensorPrimitives.IndexOfMaxMagnitude([-1, -0f, 1])); } + #endregion + #region IndexOfMin [Fact] public static void IndexOfMin_ReturnsNegative1OnEmpty() { @@ -1054,106 +1082,202 @@ public static void IndexOfMin_Negative0LesserThanPositive0() Assert.Equal(0, TensorPrimitives.IndexOfMin([-1, -0f])); Assert.Equal(0, TensorPrimitives.IndexOfMin([-1, -0f, 1])); } + #endregion + #region IndexOfMinMagnitude [Fact] - public static void IndexOfMaxMagnitude_ReturnsNegative1OnEmpty() + public static void IndexOfMinMagnitude_ReturnsNegative1OnEmpty() { - Assert.Equal(-1, TensorPrimitives.IndexOfMaxMagnitude(ReadOnlySpan.Empty)); + Assert.Equal(-1, TensorPrimitives.IndexOfMinMagnitude(ReadOnlySpan.Empty)); } [Theory] [MemberData(nameof(TensorLengths))] - public static void IndexOfMaxMagnitude(int tensorLength) + public static void IndexOfMinMagnitude(int tensorLength) { foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - x[expected] = Enumerable.Max(MemoryMarshal.ToEnumerable(x.Memory), Math.Abs) + 1; - Assert.Equal(expected, TensorPrimitives.IndexOfMaxMagnitude(x)); + using BoundedMemory x = CreateTensor(tensorLength); + for (int i = 0; i < x.Length; i++) + { + x[i] = i % 2 == 0 ? 
42 : -42; + } + + x[expected] = -41; + + Assert.Equal(expected, TensorPrimitives.IndexOfMinMagnitude(x)); } } [Theory] [MemberData(nameof(TensorLengths))] - public static void IndexOfMaxMagnitude_FirstNaNReturned(int tensorLength) + public static void IndexOfMinMagnitude_FirstNaNReturned(int tensorLength) { foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) { using BoundedMemory x = CreateAndFillTensor(tensorLength); x[expected] = float.NaN; x[tensorLength - 1] = float.NaN; - Assert.Equal(expected, TensorPrimitives.IndexOfMaxMagnitude(x)); + Assert.Equal(expected, TensorPrimitives.IndexOfMinMagnitude(x)); } } [Fact] - public static void IndexOfMaxMagnitude_Negative0LesserThanPositive0() + public static void IndexOfMinMagnitude_Negative0LesserThanPositive0() { - Assert.Equal(0, TensorPrimitives.IndexOfMaxMagnitude([-0f, -0f, -0f, -0f])); - Assert.Equal(1, TensorPrimitives.IndexOfMaxMagnitude([-0f, +0f])); - Assert.Equal(1, TensorPrimitives.IndexOfMaxMagnitude([-0f, +0f, +0f, +0f])); - Assert.Equal(0, TensorPrimitives.IndexOfMaxMagnitude([+0f, -0f])); - Assert.Equal(0, TensorPrimitives.IndexOfMaxMagnitude([-1, -0f])); - Assert.Equal(2, TensorPrimitives.IndexOfMaxMagnitude([-1, -0f, 1])); + Assert.Equal(0, TensorPrimitives.IndexOfMinMagnitude([-0f, -0f, -0f, -0f])); + Assert.Equal(0, TensorPrimitives.IndexOfMinMagnitude([-0f, +0f])); + Assert.Equal(1, TensorPrimitives.IndexOfMinMagnitude([+0f, -0f])); + Assert.Equal(1, TensorPrimitives.IndexOfMinMagnitude([+0f, -0f, -0f, -0f])); + Assert.Equal(1, TensorPrimitives.IndexOfMinMagnitude([-1, -0f])); + Assert.Equal(1, TensorPrimitives.IndexOfMinMagnitude([-1, -0f, 1])); } + #endregion - [Fact] - public static void IndexOfMinMagnitude_ReturnsNegative1OnEmpty() + #region Log + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Log(int tensorLength) { - Assert.Equal(-1, TensorPrimitives.IndexOfMinMagnitude(ReadOnlySpan.Empty)); + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.Log(x, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Log(x[i]), destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Log_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.Log(x, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Log(xOrig[i]), x[i]); + } } [Theory] [MemberData(nameof(TensorLengths))] - public static void IndexOfMinMagnitude(int tensorLength) + public static void Log_SpecialValues(int tensorLength) { - foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + RunForEachSpecialValue(() => { - using BoundedMemory x = CreateTensor(tensorLength); - for (int i = 0; i < x.Length; i++) + TensorPrimitives.Log(x, destination); + for (int i = 0; i < tensorLength; i++) { - x[i] = i % 2 == 0 ? 
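// Note the test-vector construction for IndexOfMinMagnitude: alternating 42/-42 gives
// every element the same magnitude, so planting -41 (strictly smaller in magnitude, and
// negative) at the expected index makes it the unique answer while also exercising sign
// handling in the magnitude comparison.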
42 : -42; + AssertEqualTolerance(MathF.Log(x[i]), destination[i]); } + }, x); + } - x[expected] = -41; + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Log_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); - Assert.Equal(expected, TensorPrimitives.IndexOfMinMagnitude(x)); + AssertExtensions.Throws("destination", () => TensorPrimitives.Log(x, destination)); + } + + [Fact] + public static void Log_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Log(array.AsSpan(1, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Log(array.AsSpan(1, 2), array.AsSpan(2, 2))); + } + #endregion + + #region Log2 + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Log2(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.Log2(x, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Log(x[i], 2), destination[i]); } } [Theory] - [MemberData(nameof(TensorLengths))] - public static void IndexOfMinMagnitude_FirstNaNReturned(int tensorLength) + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Log2_InPlace(int tensorLength) { - foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.Log2(x, x); + + for (int i = 0; i < tensorLength; i++) { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - x[expected] = float.NaN; - x[tensorLength - 1] = float.NaN; - Assert.Equal(expected, TensorPrimitives.IndexOfMinMagnitude(x)); + AssertEqualTolerance(MathF.Log(xOrig[i], 2), x[i]); } } + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Log2_SpecialValues(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + RunForEachSpecialValue(() => + { + TensorPrimitives.Log2(x, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Log(x[i], 2), destination[i]); + } + }, x); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Log2_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.Log2(x, destination)); + } + [Fact] - public static void IndexOfMinMagnitude_Negative0LesserThanPositive0() + public static void Log2_ThrowsForOverlapppingInputsWithOutputs() { - Assert.Equal(0, TensorPrimitives.IndexOfMinMagnitude([-0f, -0f, -0f, -0f])); - Assert.Equal(0, TensorPrimitives.IndexOfMinMagnitude([-0f, +0f])); - Assert.Equal(1, TensorPrimitives.IndexOfMinMagnitude([+0f, -0f])); - Assert.Equal(1, TensorPrimitives.IndexOfMinMagnitude([+0f, -0f, -0f, -0f])); - Assert.Equal(1, TensorPrimitives.IndexOfMinMagnitude([-1, -0f])); - Assert.Equal(1, TensorPrimitives.IndexOfMinMagnitude([-1, -0f, 1])); + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Log2(array.AsSpan(1, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => 
TensorPrimitives.Log2(array.AsSpan(1, 2), array.AsSpan(2, 2))); } + #endregion + #region Max [Fact] - public static void Max_ThrowsForEmpty() + public static void Max_Tensor_ThrowsForEmpty() { Assert.Throws(() => TensorPrimitives.Max(ReadOnlySpan.Empty)); } [Theory] [MemberData(nameof(TensorLengths))] - public static void Max(int tensorLength) + public static void Max_Tensor(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); @@ -1164,23 +1288,45 @@ public static void Max(int tensorLength) { max = Math.Max(max, f); } + Assert.Equal(max, TensorPrimitives.Max(x)); + Assert.Equal(SingleToUInt32(x[TensorPrimitives.IndexOfMax(x)]), SingleToUInt32(TensorPrimitives.Max(x))); } [Theory] [MemberData(nameof(TensorLengths))] - public static void Max_NanReturned(int tensorLength) + public static void Max_Tensor_SpecialValues(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); + + RunForEachSpecialValue(() => + { + float max = float.NegativeInfinity; + foreach (float f in x.Span) + { + max = Math.Max(max, f); + } + + Assert.Equal(max, TensorPrimitives.Max(x)); + Assert.Equal(SingleToUInt32(x[TensorPrimitives.IndexOfMax(x)]), SingleToUInt32(TensorPrimitives.Max(x))); + }, x); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Max_Tensor_NanReturned(int tensorLength) + { + using BoundedMemory x = CreateTensor(tensorLength); foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) { + FillTensor(x); x[expected] = float.NaN; Assert.Equal(float.NaN, TensorPrimitives.Max(x)); } } [Fact] - public static void Max_Negative0LesserThanPositive0() + public static void Max_Tensor_Negative0LesserThanPositive0() { Assert.Equal(+0f, TensorPrimitives.Max([-0f, +0f])); Assert.Equal(+0f, TensorPrimitives.Max([+0f, -0f])); @@ -1200,7 +1346,56 @@ public static void Max_TwoTensors(int tensorLength) for (int i = 0; i < tensorLength; i++) { - Assert.Equal(MathF.Max(x[i], y[i]), destination[i], Tolerance); + AssertEqualTolerance(MathF.Max(x[i], y[i]), destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Max_TwoTensors_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(), yOrig = y.Span.ToArray(); + + TensorPrimitives.Max(x, y, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Max(xOrig[i], y[i]), x[i]); + } + + xOrig.AsSpan().CopyTo(x.Span); + yOrig.AsSpan().CopyTo(y.Span); + + TensorPrimitives.Max(x, y, y); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Max(x[i], yOrig[i]), y[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Max_TwoTensors_SpecialValues(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + SetSpecialValues(x, y); + + TensorPrimitives.Max(x, y, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Max(x[i], y[i]), destination[i]); + } + + TensorPrimitives.Max(y, x, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Max(y[i], x[i]), destination[i]); } } @@ -1213,6 +1408,7 @@ public static void Max_TwoTensors_ThrowsForMismatchedLengths(int tensorLength) using BoundedMemory destination = CreateTensor(tensorLength); 
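// The Max tests in this region encode IEEE 754-style semantics: any NaN input poisons
// the result (Max_Tensor_NanReturned), and +0f is treated as greater than -0f. The
// SingleToUInt32 comparison additionally checks that Max returns bit-for-bit the same
// value as the element IndexOfMax identifies, e.g.:
//   TensorPrimitives.Max(new float[] { 1f, float.NaN, 3f })   // NaN
//   TensorPrimitives.Max(new float[] { -0f, +0f })            // +0f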
Assert.Throws(() => TensorPrimitives.Max(x, y, destination)); + Assert.Throws(() => TensorPrimitives.Max(y, x, destination)); } [Theory] @@ -1227,467 +1423,1570 @@ public static void Max_TwoTensors_ThrowsForTooShortDestination(int tensorLength) } [Fact] - public static void MaxMagnitude_ThrowsForEmpty() + public static void Max_TwoTensors_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Max(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Max(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Max(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Max(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2))); + } + #endregion + + #region MaxMagnitude + [Fact] + public static void MaxMagnitude_Tensor_ThrowsForEmpty() { Assert.Throws(() => TensorPrimitives.MaxMagnitude(ReadOnlySpan.Empty)); } [Theory] [MemberData(nameof(TensorLengths))] - public static void MaxMagnitude(int tensorLength) + public static void MaxMagnitude_Tensor(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - int index = 0; - for (int i = 0; i < x.Length; i++) + float maxMagnitude = x[0]; + foreach (float f in x.Span) { - if (MathF.Abs(x[i]) >= MathF.Abs(x[index])) - { - index = i; - } + maxMagnitude = MathFMaxMagnitude(maxMagnitude, f); } - Assert.Equal(x[index], TensorPrimitives.MaxMagnitude(x), Tolerance); + Assert.Equal(maxMagnitude, TensorPrimitives.MaxMagnitude(x)); + Assert.Equal(SingleToUInt32(x[TensorPrimitives.IndexOfMaxMagnitude(x)]), SingleToUInt32(TensorPrimitives.MaxMagnitude(x))); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MaxMagnitude_Tensor_SpecialValues(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + + RunForEachSpecialValue(() => + { + float maxMagnitude = x[0]; + foreach (float f in x.Span) + { + maxMagnitude = MathFMaxMagnitude(maxMagnitude, f); + } + + Assert.Equal(maxMagnitude, TensorPrimitives.MaxMagnitude(x)); + Assert.Equal(SingleToUInt32(x[TensorPrimitives.IndexOfMaxMagnitude(x)]), SingleToUInt32(TensorPrimitives.MaxMagnitude(x))); + }, x); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MaxMagnitude_Tensor_NanReturned(int tensorLength) + { + using BoundedMemory x = CreateTensor(tensorLength); + foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) + { + FillTensor(x); + x[expected] = float.NaN; + Assert.Equal(float.NaN, TensorPrimitives.MaxMagnitude(x)); + } + } + + [Fact] + public static void MaxMagnitude_Tensor_Negative0LesserThanPositive0() + { + Assert.Equal(+0f, TensorPrimitives.MaxMagnitude([-0f, +0f])); + Assert.Equal(+0f, TensorPrimitives.MaxMagnitude([+0f, -0f])); + Assert.Equal(-1, TensorPrimitives.MaxMagnitude([-1, -0f])); + Assert.Equal(1, TensorPrimitives.MaxMagnitude([-1, -0f, 1])); + Assert.Equal(0f, TensorPrimitives.MaxMagnitude([-0f, -0f, -0f, -0f, -0f, 0f])); + Assert.Equal(1, TensorPrimitives.MaxMagnitude([-0f, -0f, -0f, -0f, -1, -0f, 0f, 1])); + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void MaxMagnitude_TwoTensors(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory 
destination = CreateTensor(tensorLength); + + TensorPrimitives.MaxMagnitude(x, y, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathFMaxMagnitude(x[i], y[i]), destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void MaxMagnitude_TwoTensors_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(), yOrig = y.Span.ToArray(); + + TensorPrimitives.MaxMagnitude(x, y, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathFMaxMagnitude(xOrig[i], y[i]), x[i]); + } + + xOrig.AsSpan().CopyTo(x.Span); + yOrig.AsSpan().CopyTo(y.Span); + + TensorPrimitives.MaxMagnitude(x, y, y); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathFMaxMagnitude(x[i], yOrig[i]), y[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MaxMagnitude_TwoTensors_SpecialValues(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + SetSpecialValues(x, y); + + TensorPrimitives.MaxMagnitude(x, y, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathFMaxMagnitude(x[i], y[i]), destination[i]); + } + + TensorPrimitives.MaxMagnitude(y, x, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathFMaxMagnitude(y[i], x[i]), destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MaxMagnitude_TwoTensors_ThrowsForMismatchedLengths(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); + using BoundedMemory destination = CreateTensor(tensorLength); + + Assert.Throws(() => TensorPrimitives.MaxMagnitude(x, y, destination)); + Assert.Throws(() => TensorPrimitives.MaxMagnitude(y, x, destination)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MaxMagnitude_TwoTensors_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.MaxMagnitude(x, y, destination)); + } + + [Fact] + public static void MaxMagnitude_TwoTensors_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.MaxMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MaxMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MaxMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MaxMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2))); + } + #endregion + + #region Min + [Fact] + public static void Min_Tensor_ThrowsForEmpty() + { + Assert.Throws(() => TensorPrimitives.Min(ReadOnlySpan.Empty)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Min_Tensor(int tensorLength) + { + using BoundedMemory x = 
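// MathFMaxMagnitude / MathFMinMagnitude are helpers defined elsewhere in this test file.
// A sketch of the expected semantics (an assumption, not the verbatim helper), matching
// MathF.MaxMagnitude: NaN propagates, the larger |value| wins, and on equal magnitudes
// the positively signed operand wins:
//   static float MathFMaxMagnitude(float x, float y)
//   {
//       if (float.IsNaN(x) || float.IsNaN(y)) return float.NaN;
//       float ax = MathF.Abs(x), ay = MathF.Abs(y);
//       return ax > ay || (ax == ay && !float.IsNegative(x)) ? x : y;
//   }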
CreateAndFillTensor(tensorLength); + + Assert.Equal(Enumerable.Min(MemoryMarshal.ToEnumerable(x.Memory)), TensorPrimitives.Min(x)); + + float min = float.PositiveInfinity; + foreach (float f in x.Span) + { + min = Math.Min(min, f); + } + + Assert.Equal(min, TensorPrimitives.Min(x)); + Assert.Equal(SingleToUInt32(x[TensorPrimitives.IndexOfMin(x)]), SingleToUInt32(TensorPrimitives.Min(x))); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Min_Tensor_SpecialValues(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + + RunForEachSpecialValue(() => + { + float min = float.PositiveInfinity; + foreach (float f in x.Span) + { + min = Math.Min(min, f); + } + + Assert.Equal(min, TensorPrimitives.Min(x)); + Assert.Equal(SingleToUInt32(x[TensorPrimitives.IndexOfMin(x)]), SingleToUInt32(TensorPrimitives.Min(x))); + }, x); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Min_Tensor_NanReturned(int tensorLength) + { + using BoundedMemory x = CreateTensor(tensorLength); + foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) + { + FillTensor(x); + x[expected] = float.NaN; + Assert.Equal(float.NaN, TensorPrimitives.Min(x)); + } + } + + [Fact] + public static void Min_Tensor_Negative0LesserThanPositive0() + { + Assert.Equal(-0f, TensorPrimitives.Min([-0f, +0f])); + Assert.Equal(-0f, TensorPrimitives.Min([+0f, -0f])); + Assert.Equal(-1, TensorPrimitives.Min([-1, -0f])); + Assert.Equal(-1, TensorPrimitives.Min([-1, -0f, 1])); + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Min_TwoTensors(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.Min(x, y, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Min(x[i], y[i]), destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Min_TwoTensors_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(), yOrig = y.Span.ToArray(); + + TensorPrimitives.Min(x, y, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Min(xOrig[i], y[i]), x[i]); + } + + xOrig.AsSpan().CopyTo(x.Span); + yOrig.AsSpan().CopyTo(y.Span); + + TensorPrimitives.Min(x, y, y); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Min(x[i], yOrig[i]), y[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Min_TwoTensors_SpecialValues(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + SetSpecialValues(x, y); + + TensorPrimitives.Min(x, y, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Min(x[i], y[i]), destination[i]); + } + + TensorPrimitives.Min(y, x, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Min(y[i], x[i]), destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Min_TwoTensors_ThrowsForMismatchedLengths(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = 
CreateAndFillTensor(tensorLength - 1); + using BoundedMemory destination = CreateTensor(tensorLength); + + Assert.Throws(() => TensorPrimitives.Min(x, y, destination)); + Assert.Throws(() => TensorPrimitives.Min(y, x, destination)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Min_TwoTensors_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.Min(x, y, destination)); + } + + [Fact] + public static void Min_TwoTensors_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Min(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Min(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Min(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Min(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2))); + } + #endregion + + #region MinMagnitude + [Fact] + public static void MinMagnitude_Tensor_ThrowsForEmpty() + { + Assert.Throws(() => TensorPrimitives.MinMagnitude(ReadOnlySpan.Empty)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MinMagnitude_Tensor(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + + float minMagnitude = x[0]; + foreach (float f in x.Span) + { + minMagnitude = MathFMinMagnitude(minMagnitude, f); + } + + Assert.Equal(minMagnitude, TensorPrimitives.MinMagnitude(x)); + Assert.Equal(SingleToUInt32(x[TensorPrimitives.IndexOfMinMagnitude(x)]), SingleToUInt32(TensorPrimitives.MinMagnitude(x))); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MinMagnitude_Tensor_SpecialValues(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + + RunForEachSpecialValue(() => + { + float minMagnitude = x[0]; + foreach (float f in x.Span) + { + minMagnitude = MathFMinMagnitude(minMagnitude, f); + } + + Assert.Equal(minMagnitude, TensorPrimitives.MinMagnitude(x)); + Assert.Equal(SingleToUInt32(x[TensorPrimitives.IndexOfMinMagnitude(x)]), SingleToUInt32(TensorPrimitives.MinMagnitude(x))); + }, x); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MinMagnitude_Tensor_NanReturned(int tensorLength) + { + using BoundedMemory x = CreateTensor(tensorLength); + foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) + { + FillTensor(x); + x[expected] = float.NaN; + Assert.Equal(float.NaN, TensorPrimitives.MinMagnitude(x)); + } + } + + [Fact] + public static void MinMagnitude_Tensor_Negative0LesserThanPositive0() + { + Assert.Equal(0, TensorPrimitives.MinMagnitude([-0f, +0f])); + Assert.Equal(0, TensorPrimitives.MinMagnitude([+0f, -0f])); + Assert.Equal(0, TensorPrimitives.MinMagnitude([-1, -0f])); + Assert.Equal(0, TensorPrimitives.MinMagnitude([-1, -0f, 1])); + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void MinMagnitude_TwoTensors(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = 
CreateTensor(tensorLength); + + TensorPrimitives.MinMagnitude(x, y, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathFMinMagnitude(x[i], y[i]), destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void MinMagnitude_TwoTensors_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(), yOrig = y.Span.ToArray(); + + TensorPrimitives.MinMagnitude(x, y, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathFMinMagnitude(xOrig[i], y[i]), x[i]); + } + + xOrig.AsSpan().CopyTo(x.Span); + yOrig.AsSpan().CopyTo(y.Span); + + TensorPrimitives.MinMagnitude(x, y, y); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathFMinMagnitude(x[i], yOrig[i]), y[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MinMagnitude_TwoTensors_SpecialValues(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + SetSpecialValues(x, y); + + TensorPrimitives.MinMagnitude(x, y, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathFMinMagnitude(x[i], y[i]), destination[i]); + } + + TensorPrimitives.MinMagnitude(y, x, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathFMinMagnitude(y[i], x[i]), destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MinMagnitude_TwoTensors_ThrowsForMismatchedLengths(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); + using BoundedMemory destination = CreateTensor(tensorLength); + + Assert.Throws(() => TensorPrimitives.MinMagnitude(x, y, destination)); + Assert.Throws(() => TensorPrimitives.MinMagnitude(y, x, destination)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MinMagnitude_TwoTensors_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.MinMagnitude(x, y, destination)); + } + + [Fact] + public static void MinMagnitude_TwoTensors_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.MinMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MinMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MinMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MinMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2))); + } + #endregion + + #region Multiply + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Multiply_TwoTensors(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = 
CreateTensor(tensorLength); + + TensorPrimitives.Multiply(x, y, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(x[i] * y[i], destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Multiply_TwoTensors_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.Multiply(x, x, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(xOrig[i] * xOrig[i], x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Multiply_TwoTensors_ThrowsForMismatchedLengths(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); + using BoundedMemory destination = CreateTensor(tensorLength); + + Assert.Throws(() => TensorPrimitives.Multiply(x, y, destination)); + Assert.Throws(() => TensorPrimitives.Multiply(y, x, destination)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Multiply_TwoTensors_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(x, y, destination)); + } + + [Fact] + public static void Multiply_TwoTensors_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2))); + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Multiply_TensorScalar(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float y = NextSingle(); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.Multiply(x, y, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(x[i] * y, destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Multiply_TensorScalar_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + float y = NextSingle(); + + TensorPrimitives.Multiply(x, y, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(xOrig[i] * y, x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Multiply_TensorScalar_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float y = NextSingle(); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(x, y, destination)); + } + + [Fact] + public static void Multiply_TensorScalar_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => 
TensorPrimitives.Multiply(array.AsSpan(1, 2), 42, array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(array.AsSpan(1, 2), 42, array.AsSpan(2, 2))); + } + #endregion + + #region MultiplyAdd + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void MultiplyAdd_ThreeTensors(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory addend = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.MultiplyAdd(x, y, addend, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance((x[i] * y[i]) + addend[i], destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void MultiplyAdd_ThreeTensors_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.MultiplyAdd(x, x, x, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance((xOrig[i] * xOrig[i]) + xOrig[i], x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MultiplyAdd_ThreeTensors_ThrowsForMismatchedLengths_x_y(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory z = CreateAndFillTensor(tensorLength - 1); + using BoundedMemory destination = CreateTensor(tensorLength); + + Assert.Throws(() => TensorPrimitives.MultiplyAdd(x, y, z, destination)); + Assert.Throws(() => TensorPrimitives.MultiplyAdd(x, z, y, destination)); + Assert.Throws(() => TensorPrimitives.MultiplyAdd(z, x, y, destination)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MultiplyAdd_ThreeTensors_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory addend = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); + } + + [Fact] + public static void MultiplyAdd_ThreeTensors_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(5, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(6, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(8, 2))); + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void MultiplyAdd_TensorTensorScalar(int 
tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + float addend = NextSingle(); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.MultiplyAdd(x, y, addend, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance((x[i] * y[i]) + addend, destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void MultiplyAdd_TensorTensorScalar_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + float addend = NextSingle(); + + TensorPrimitives.MultiplyAdd(x, x, addend, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance((xOrig[i] * xOrig[i]) + addend, x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MultiplyAdd_TensorTensorScalar_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + float addend = NextSingle(); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); + } + + [Fact] + public static void MultiplyAdd_TensorTensorScalar_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), 42, array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), 42, array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), 42, array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), 42, array.AsSpan(5, 2))); + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void MultiplyAdd_TensorScalarTensor(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float y = NextSingle(); + using BoundedMemory addend = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.MultiplyAdd(x, y, addend, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance((x[i] * y) + addend[i], destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void MultiplyAdd_TensorScalarTensor_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + float y = NextSingle(); + + TensorPrimitives.MultiplyAdd(x, y, x, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance((xOrig[i] * y) + xOrig[i], x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MultiplyAdd_TensorScalarTensor_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float y = NextSingle(); + using BoundedMemory addend = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); + } + + [Fact] + public static void 
MultiplyAdd_TensorScalarTensor_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), 42, array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), 42, array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), 42, array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), 42, array.AsSpan(4, 2), array.AsSpan(5, 2))); + } + #endregion + + #region Negate + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Negate(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.Negate(x, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(-x[i], destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Negate_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.Negate(x, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(-xOrig[i], x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Negate_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.Negate(x, destination)); + } + + [Fact] + public static void Negate_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Negate(array.AsSpan(1, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Negate(array.AsSpan(1, 2), array.AsSpan(2, 2))); + } + #endregion + + #region Norm + [Theory] + [InlineData(new float[] { 1, 2, 3 }, 3.7416575f)] + [InlineData(new float[] { 3, 4 }, 5)] + [InlineData(new float[] { 3 }, 3)] + [InlineData(new float[] { 3, 4, 1, 2 }, 5.477226)] + [InlineData(new float[] { }, 0f)] + public static void Norm_KnownValues(float[] x, float expectedResult) + { + AssertEqualTolerance(expectedResult, TensorPrimitives.Norm(x)); + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Norm(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + + float sumOfSquares = 0f; + for (int i = 0; i < x.Length; i++) + { + sumOfSquares += x[i] * x[i]; + } + + AssertEqualTolerance(MathF.Sqrt(sumOfSquares), TensorPrimitives.Norm(x)); + } + #endregion + + #region Product + [Fact] + public static void Product_ThrowsForEmpty() + { + Assert.Throws(() => TensorPrimitives.Product(ReadOnlySpan.Empty)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Product(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + + float f = x[0]; + for (int i = 1; i < x.Length; i++) + { + f *= x[i]; + } + + AssertEqualTolerance(f, TensorPrimitives.Product(x)); + } + + [Theory] + [InlineData(1, new float[] { 1 })] + [InlineData(-2, new float[] { 1, -2 })] + [InlineData(-6, new float[] { 1, -2, 3 })] + [InlineData(24, new float[] { 1, -2, 3, -4 
})] + [InlineData(120, new float[] { 1, -2, 3, -4, 5 })] + [InlineData(-720, new float[] { 1, -2, 3, -4, 5, -6 })] + [InlineData(0, new float[] { 1, -2, 3, -4, 5, -6, 0 })] + [InlineData(0, new float[] { 0, 1, -2, 3, -4, 5, -6 })] + [InlineData(0, new float[] { 1, -2, 3, 0, -4, 5, -6 })] + [InlineData(float.NaN, new float[] { 1, -2, 3, float.NaN, -4, 5, -6 })] + public static void Product_KnownValues(float expected, float[] input) + { + Assert.Equal(expected, TensorPrimitives.Product(input)); + } + #endregion + + #region ProductOfDifferences + [Fact] + public static void ProductOfDifferences_ThrowsForEmptyAndMismatchedLengths() + { + Assert.Throws(() => TensorPrimitives.ProductOfDifferences(ReadOnlySpan.Empty, ReadOnlySpan.Empty)); + Assert.Throws(() => TensorPrimitives.ProductOfDifferences(ReadOnlySpan.Empty, CreateTensor(1))); + Assert.Throws(() => TensorPrimitives.ProductOfDifferences(CreateTensor(1), ReadOnlySpan.Empty)); + Assert.Throws(() => TensorPrimitives.ProductOfDifferences(CreateTensor(44), CreateTensor(43))); + Assert.Throws(() => TensorPrimitives.ProductOfDifferences(CreateTensor(43), CreateTensor(44))); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void ProductOfDifferences(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + + float f = x[0] - y[0]; + for (int i = 1; i < x.Length; i++) + { + f *= x[i] - y[i]; + } + AssertEqualTolerance(f, TensorPrimitives.ProductOfDifferences(x, y)); + } + + [Theory] + [InlineData(0, new float[] {0 }, new float[] {0})] + [InlineData(0, new float[] {1 }, new float[] {1})] + [InlineData(1, new float[] {1 }, new float[] {0})] + [InlineData(-1, new float[] {0 }, new float[] {1})] + [InlineData(-1, new float[] {1, 2, 3, 4, 5 }, new float[] {2, 3, 4, 5, 6})] + [InlineData(120, new float[] {1, 2, 3, 4, 5 }, new float[] {0, 0, 0, 0, 0})] + [InlineData(-120, new float[] {0, 0, 0, 0, 0 }, new float[] {1, 2, 3, 4, 5})] + [InlineData(float.NaN, new float[] {1, 2, float.NaN, 4, 5 }, new float[] {0, 0, 0, 0, 0})] + public static void ProductOfDifferences_KnownValues(float expected, float[] x, float[] y) + { + Assert.Equal(expected, TensorPrimitives.ProductOfDifferences(x, y)); + + } + #endregion + + #region ProductOfSums + [Fact] + public static void ProductOfSums_ThrowsForEmptyAndMismatchedLengths() + { + Assert.Throws(() => TensorPrimitives.ProductOfSums(ReadOnlySpan.Empty, ReadOnlySpan.Empty)); + Assert.Throws(() => TensorPrimitives.ProductOfSums(ReadOnlySpan.Empty, CreateTensor(1))); + Assert.Throws(() => TensorPrimitives.ProductOfSums(CreateTensor(1), ReadOnlySpan.Empty)); + Assert.Throws(() => TensorPrimitives.ProductOfSums(CreateTensor(44), CreateTensor(43))); + Assert.Throws(() => TensorPrimitives.ProductOfSums(CreateTensor(43), CreateTensor(44))); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void ProductOfSums(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + + float f = x[0] + y[0]; + for (int i = 1; i < x.Length; i++) + { + f *= x[i] + y[i]; + } + AssertEqualTolerance(f, TensorPrimitives.ProductOfSums(x, y)); + } + + [Theory] + [InlineData(0, new float[] {0 }, new float[] { 0 })] + [InlineData(1, new float[] {0 }, new float[] { 1 })] + [InlineData(1, new float[] {1 }, new float[] { 0 })] + [InlineData(2, new float[] {1 }, new float[] { 1 })] + [InlineData(10395, new float[] {1, 2, 3, 4, 5 }, new float[] { 2, 
3, 4, 5, 6 })] + [InlineData(120, new float[] {1, 2, 3, 4, 5 }, new float[] { 0, 0, 0, 0, 0 })] + [InlineData(120, new float[] {0, 0, 0, 0, 0 }, new float[] { 1, 2, 3, 4, 5 })] + [InlineData(float.NaN, new float[] {1, 2, float.NaN, 4, 5 }, new float[] { 0, 0, 0, 0, 0 })] + public static void ProductOfSums_KnownValues(float expected, float[] x, float[] y) + { + Assert.Equal(expected, TensorPrimitives.ProductOfSums(x, y)); + } + #endregion + + #region Sigmoid + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Sigmoid(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.Sigmoid(x, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(1f / (1f + MathF.Exp(-x[i])), destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Sigmoid_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.Sigmoid(x, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(1f / (1f + MathF.Exp(-xOrig[i])), x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Sigmoid_SpecialValues(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + RunForEachSpecialValue(() => + { + TensorPrimitives.Sigmoid(x, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(1f / (1f + MathF.Exp(-x[i])), destination[i]); + } + }, x); + } + + [Theory] + [InlineData(new float[] { -5, -4.5f, -4 }, new float[] { 0.0066f, 0.0109f, 0.0179f })] + [InlineData(new float[] { 4.5f, 5 }, new float[] { 0.9890f, 0.9933f })] + [InlineData(new float[] { 0, -3, 3, .5f }, new float[] { 0.5f, 0.0474f, 0.9525f, 0.6224f })] + public static void Sigmoid_KnownValues(float[] x, float[] expectedResult) + { + using BoundedMemory dest = CreateTensor(x.Length); + TensorPrimitives.Sigmoid(x, dest); + + for (int i = 0; i < x.Length; i++) + { + AssertEqualTolerance(expectedResult[i], dest[i], 0.0001f); + } + } + + [Theory] + [InlineData(new float[] { -5, -4.5f, -4 }, new float[] { 0.0066f, 0.0109f, 0.0179f })] + public static void Sigmoid_DestinationLongerThanSource(float[] x, float[] expectedResult) + { + using BoundedMemory dest = CreateTensor(x.Length + 1); + + TensorPrimitives.Sigmoid(x, dest); + + float originalLast = dest[dest.Length - 1]; + for (int i = 0; i < x.Length; i++) + { + AssertEqualTolerance(expectedResult[i], dest[i], 0.0001f); + } + Assert.Equal(originalLast, dest[dest.Length - 1]); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Sigmoid_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.Sigmoid(x, destination)); } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void MaxMagnitude_NanReturned(int tensorLength) + [Fact] + public static void Sigmoid_ThrowsForEmptyInput() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) - { - x[expected] = float.NaN; - Assert.Equal(float.NaN, TensorPrimitives.MaxMagnitude(x)); - } + AssertExtensions.Throws(() => 
TensorPrimitives.Sigmoid(ReadOnlySpan.Empty, CreateTensor(1))); } [Fact] - public static void MaxMagnitude_Negative0LesserThanPositive0() + public static void Sigmoid_ThrowsForOverlapppingInputsWithOutputs() { - Assert.Equal(+0f, TensorPrimitives.MaxMagnitude([-0f, +0f])); - Assert.Equal(+0f, TensorPrimitives.MaxMagnitude([+0f, -0f])); - Assert.Equal(-1, TensorPrimitives.MaxMagnitude([-1, -0f])); - Assert.Equal(1, TensorPrimitives.MaxMagnitude([-1, -0f, 1])); - Assert.Equal(0f, TensorPrimitives.MaxMagnitude([-0f, -0f, -0f, -0f, -0f, 0f])); - Assert.Equal(1, TensorPrimitives.MaxMagnitude([-0f, -0f, -0f, -0f, -1, -0f, 0f, 1])); + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Sigmoid(array.AsSpan(1, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Sigmoid(array.AsSpan(1, 2), array.AsSpan(2, 2))); } + #endregion + #region Sinh [Theory] [MemberData(nameof(TensorLengthsIncluding0))] - public static void MaxMagnitude_TwoTensors(int tensorLength) + public static void Sinh(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); using BoundedMemory destination = CreateTensor(tensorLength); - TensorPrimitives.MaxMagnitude(x, y, destination); + TensorPrimitives.Sinh(x, destination); for (int i = 0; i < tensorLength; i++) { - Assert.Equal(MathF.Abs(x[i]) >= MathF.Abs(y[i]) ? x[i] : y[i], destination[i], Tolerance); + AssertEqualTolerance(MathF.Sinh(x[i]), destination[i]); } } [Theory] - [MemberData(nameof(TensorLengths))] - public static void MaxMagnitude_TwoTensors_ThrowsForMismatchedLengths(int tensorLength) + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Sinh_InPlace(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); - using BoundedMemory destination = CreateTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); - Assert.Throws(() => TensorPrimitives.MaxMagnitude(x, y, destination)); + TensorPrimitives.Sinh(x, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Sinh(xOrig[i]), x[i]); + } } [Theory] [MemberData(nameof(TensorLengths))] - public static void MaxMagnitude_TwoTensors_ThrowsForTooShortDestination(int tensorLength) + public static void Sinh_SpecialValues(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength - 1); - - AssertExtensions.Throws("destination", () => TensorPrimitives.MaxMagnitude(x, y, destination)); - } + using BoundedMemory destination = CreateTensor(tensorLength); - [Fact] - public static void Min_ThrowsForEmpty() - { - Assert.Throws(() => TensorPrimitives.Min(ReadOnlySpan.Empty)); + RunForEachSpecialValue(() => + { + TensorPrimitives.Sinh(x, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Sinh(x[i]), destination[i]); + } + }, x); } [Theory] - [MemberData(nameof(TensorLengths))] - public static void Min(int tensorLength) + [MemberData(nameof(VectorLengthAndIteratedRange), new object[] { -100f, 100f, 3f })] + public static void Sinh_ValueRange(int vectorLengths, float element) { - using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] x = new float[vectorLengths]; + float[] dest = new float[vectorLengths]; - 
Assert.Equal(Enumerable.Min(MemoryMarshal.ToEnumerable(x.Memory)), TensorPrimitives.Min(x)); + x.AsSpan().Fill(element); + TensorPrimitives.Sinh(x, dest); - float min = float.PositiveInfinity; - foreach (float f in x.Span) + float expected = MathF.Sinh(element); + foreach (float actual in dest) { - min = Math.Min(min, f); + AssertEqualTolerance(expected, actual); } - Assert.Equal(min, TensorPrimitives.Min(x)); } [Theory] [MemberData(nameof(TensorLengths))] - public static void Min_NanReturned(int tensorLength) + public static void Sinh_ThrowsForTooShortDestination(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) - { - x[expected] = float.NaN; - Assert.Equal(float.NaN, TensorPrimitives.Min(x)); - } + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.Sinh(x, destination)); } [Fact] - public static void Min_Negative0LesserThanPositive0() + public static void Sinh_ThrowsForOverlapppingInputsWithOutputs() { - Assert.Equal(-0f, TensorPrimitives.Min([-0f, +0f])); - Assert.Equal(-0f, TensorPrimitives.Min([+0f, -0f])); - Assert.Equal(-1, TensorPrimitives.Min([-1, -0f])); - Assert.Equal(-1, TensorPrimitives.Min([-1, -0f, 1])); + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Sinh(array.AsSpan(1, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Sinh(array.AsSpan(1, 2), array.AsSpan(2, 2))); } + #endregion + #region SoftMax [Theory] - [MemberData(nameof(TensorLengthsIncluding0))] - public static void Min_TwoTensors(int tensorLength) + [MemberData(nameof(TensorLengths))] + public static void SoftMax(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); using BoundedMemory destination = CreateTensor(tensorLength); - TensorPrimitives.Min(x, y, destination); + TensorPrimitives.SoftMax(x, destination); + float expSum = MemoryMarshal.ToEnumerable(x.Memory).Sum(MathF.Exp); for (int i = 0; i < tensorLength; i++) { - Assert.Equal(MathF.Min(x[i], y[i]), destination[i], Tolerance); + AssertEqualTolerance(MathF.Exp(x[i]) / expSum, destination[i]); } } [Theory] [MemberData(nameof(TensorLengths))] - public static void Min_TwoTensors_ThrowsForMismatchedLengths(int tensorLength) + public static void SoftMax_InPlace(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); - using BoundedMemory destination = CreateTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); - Assert.Throws(() => TensorPrimitives.Min(x, y, destination)); + TensorPrimitives.SoftMax(x, x); + + float expSum = xOrig.Sum(MathF.Exp); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Exp(xOrig[i]) / expSum, x[i]); + } } [Theory] - [MemberData(nameof(TensorLengths))] - public static void Min_TwoTensors_ThrowsForTooShortDestination(int tensorLength) + [InlineData(new float[] { 3, 1, .2f }, new float[] { 0.8360188f, 0.11314284f, 0.05083836f })] + [InlineData(new float[] { 3, 4, 1 }, new float[] { 0.2594f, 0.705384f, 0.0351f })] + [InlineData(new float[] { 5, 3 }, new float[] { 0.8807f, 0.1192f })] + [InlineData(new float[] { 4, 2, 1, 9 }, new float[] { 0.0066f, 9.04658e-4f, 3.32805e-4f, 0.9920f })] + public static void SoftMax_KnownValues(float[] x, float[] expectedResult) { - using 
BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength - 1); + using BoundedMemory dest = CreateTensor(x.Length); + TensorPrimitives.SoftMax(x, dest); - AssertExtensions.Throws("destination", () => TensorPrimitives.Min(x, y, destination)); + for (int i = 0; i < x.Length; i++) + { + AssertEqualTolerance(expectedResult[i], dest[i], 0.0001f); + } } [Fact] - public static void MinMagnitude_ThrowsForEmpty() - { - Assert.Throws(() => TensorPrimitives.MinMagnitude(ReadOnlySpan.Empty)); - } - - [Theory] - [MemberData(nameof(TensorLengths))] - public static void MinMagnitude(int tensorLength) + public static void SoftMax_DestinationLongerThanSource() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] x = [3, 1, .2f]; + float[] expectedResult = [0.8360188f, 0.11314284f, 0.05083836f]; + using BoundedMemory dest = CreateTensor(x.Length + 1); + TensorPrimitives.SoftMax(x, dest); - int index = 0; for (int i = 0; i < x.Length; i++) { - if (MathF.Abs(x[i]) < MathF.Abs(x[index])) - { - index = i; - } + AssertEqualTolerance(expectedResult[i], dest[i]); } - - Assert.Equal(x[index], TensorPrimitives.MinMagnitude(x), Tolerance); } [Theory] [MemberData(nameof(TensorLengths))] - public static void MinMagnitude_NanReturned(int tensorLength) + public static void SoftMax_ThrowsForTooShortDestination(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) - { - x[expected] = float.NaN; - Assert.Equal(float.NaN, TensorPrimitives.MinMagnitude(x)); - } + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.SoftMax(x, destination)); } [Fact] - public static void MinMagnitude_Negative0LesserThanPositive0() + public static void SoftMax_ThrowsForEmptyInput() { - Assert.Equal(0, TensorPrimitives.MinMagnitude([-0f, +0f])); - Assert.Equal(0, TensorPrimitives.MinMagnitude([+0f, -0f])); - Assert.Equal(0, TensorPrimitives.MinMagnitude([-1, -0f])); - Assert.Equal(0, TensorPrimitives.MinMagnitude([-1, -0f, 1])); + AssertExtensions.Throws(() => TensorPrimitives.SoftMax(ReadOnlySpan.Empty, CreateTensor(1))); + } + + [Fact] + public static void SoftMax_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.SoftMax(array.AsSpan(1, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.SoftMax(array.AsSpan(1, 2), array.AsSpan(2, 2))); } + #endregion + #region Subtract [Theory] [MemberData(nameof(TensorLengthsIncluding0))] - public static void MinMagnitude_TwoTensors(int tensorLength) + public static void Subtract_TwoTensors(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); using BoundedMemory y = CreateAndFillTensor(tensorLength); using BoundedMemory destination = CreateTensor(tensorLength); - TensorPrimitives.MinMagnitude(x, y, destination); + TensorPrimitives.Subtract(x, y, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(x[i] - y[i], destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Subtract_TwoTensors_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.Subtract(x, x, x); for (int i = 
0; i < tensorLength; i++) { - Assert.Equal(MathF.Abs(x[i]) < MathF.Abs(y[i]) ? x[i] : y[i], destination[i], Tolerance); + AssertEqualTolerance(xOrig[i] - xOrig[i], x[i]); } } [Theory] [MemberData(nameof(TensorLengths))] - public static void MinMagnitude_TwoTensors_ThrowsForMismatchedLengths(int tensorLength) + public static void Subtract_TwoTensors_ThrowsForMismatchedLengths(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); using BoundedMemory destination = CreateTensor(tensorLength); - Assert.Throws(() => TensorPrimitives.MinMagnitude(x, y, destination)); + Assert.Throws(() => TensorPrimitives.Subtract(x, y, destination)); + Assert.Throws(() => TensorPrimitives.Subtract(y, x, destination)); } [Theory] [MemberData(nameof(TensorLengths))] - public static void MinMagnitude_TwoTensors_ThrowsForTooShortDestination(int tensorLength) + public static void Subtract_TwoTensors_ThrowsForTooShortDestination(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); using BoundedMemory y = CreateAndFillTensor(tensorLength); using BoundedMemory destination = CreateTensor(tensorLength - 1); - AssertExtensions.Throws("destination", () => TensorPrimitives.MinMagnitude(x, y, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.Subtract(x, y, destination)); } [Fact] - public static void Product_ThrowsForEmpty() + public static void Subtract_TwoTensors_ThrowsForOverlapppingInputsWithOutputs() { - Assert.Throws(() => TensorPrimitives.Product(ReadOnlySpan.Empty)); + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Subtract(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Subtract(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Subtract(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Subtract(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2))); } [Theory] - [MemberData(nameof(TensorLengths))] - public static void Product(int tensorLength) + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Subtract_TensorScalar(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); + float y = NextSingle(); + using BoundedMemory destination = CreateTensor(tensorLength); - float f = x[0]; - for (int i = 1; i < x.Length; i++) + TensorPrimitives.Subtract(x, y, destination); + + for (int i = 0; i < tensorLength; i++) { - f *= x[i]; + AssertEqualTolerance(x[i] - y, destination[i]); } - - Assert.Equal(f, TensorPrimitives.Product(x), Tolerance); - } - - [Fact] - public static void Product_KnownValues() - { - Assert.Equal(1, TensorPrimitives.Product([1])); - Assert.Equal(-2, TensorPrimitives.Product([1, -2])); - Assert.Equal(-6, TensorPrimitives.Product([1, -2, 3])); - Assert.Equal(24, TensorPrimitives.Product([1, -2, 3, -4])); - Assert.Equal(120, TensorPrimitives.Product([1, -2, 3, -4, 5])); - Assert.Equal(-720, TensorPrimitives.Product([1, -2, 3, -4, 5, -6])); - Assert.Equal(0, TensorPrimitives.Product([1, -2, 3, -4, 5, -6, 0])); - Assert.Equal(0, TensorPrimitives.Product([0, 1, -2, 3, -4, 5, -6])); - Assert.Equal(0, TensorPrimitives.Product([1, -2, 3, 0, -4, 5, -6])); - Assert.Equal(float.NaN, TensorPrimitives.Product([1, -2, 3, float.NaN, -4, 5, -6])); - } - - [Fact] - 
public static void ProductOfDifferences_ThrowsForEmptyAndMismatchedLengths() - { - Assert.Throws(() => TensorPrimitives.ProductOfDifferences(ReadOnlySpan.Empty, ReadOnlySpan.Empty)); - Assert.Throws(() => TensorPrimitives.ProductOfDifferences(ReadOnlySpan.Empty, CreateTensor(1))); - Assert.Throws(() => TensorPrimitives.ProductOfDifferences(CreateTensor(1), ReadOnlySpan.Empty)); - Assert.Throws(() => TensorPrimitives.ProductOfDifferences(CreateTensor(44), CreateTensor(43))); - Assert.Throws(() => TensorPrimitives.ProductOfDifferences(CreateTensor(43), CreateTensor(44))); } [Theory] - [MemberData(nameof(TensorLengths))] - public static void ProductOfDifferences(int tensorLength) + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Subtract_TensorScalar_InPlace(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + float y = NextSingle(); - float f = x[0] - y[0]; - for (int i = 1; i < x.Length; i++) + TensorPrimitives.Subtract(x, y, x); + + for (int i = 0; i < tensorLength; i++) { - f *= x[i] - y[i]; + AssertEqualTolerance(xOrig[i] - y, x[i]); } - Assert.Equal(f, TensorPrimitives.ProductOfDifferences(x, y), Tolerance); } - [Fact] - public static void ProductOfDifferences_KnownValues() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Subtract_TensorScalar_ThrowsForTooShortDestination(int tensorLength) { - Assert.Equal(0, TensorPrimitives.ProductOfDifferences([0], [0])); - Assert.Equal(0, TensorPrimitives.ProductOfDifferences([1], [1])); - Assert.Equal(1, TensorPrimitives.ProductOfDifferences([1], [0])); - Assert.Equal(-1, TensorPrimitives.ProductOfDifferences([0], [1])); - Assert.Equal(-1, TensorPrimitives.ProductOfDifferences([1, 2, 3, 4, 5], [2, 3, 4, 5, 6])); - Assert.Equal(120, TensorPrimitives.ProductOfDifferences([1, 2, 3, 4, 5], [0, 0, 0, 0, 0])); - Assert.Equal(-120, TensorPrimitives.ProductOfDifferences([0, 0, 0, 0, 0], [1, 2, 3, 4, 5])); - Assert.Equal(float.NaN, TensorPrimitives.ProductOfDifferences([1, 2, float.NaN, 4, 5], [0, 0, 0, 0, 0])); + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float y = NextSingle(); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.Subtract(x, y, destination)); } [Fact] - public static void ProductOfSums_ThrowsForEmptyAndMismatchedLengths() + public static void Subtract_TensorScalar_ThrowsForOverlapppingInputsWithOutputs() { - Assert.Throws(() => TensorPrimitives.ProductOfSums(ReadOnlySpan.Empty, ReadOnlySpan.Empty)); - Assert.Throws(() => TensorPrimitives.ProductOfSums(ReadOnlySpan.Empty, CreateTensor(1))); - Assert.Throws(() => TensorPrimitives.ProductOfSums(CreateTensor(1), ReadOnlySpan.Empty)); - Assert.Throws(() => TensorPrimitives.ProductOfSums(CreateTensor(44), CreateTensor(43))); - Assert.Throws(() => TensorPrimitives.ProductOfSums(CreateTensor(43), CreateTensor(44))); + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Subtract(array.AsSpan(1, 2), 42, array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Subtract(array.AsSpan(1, 2), 42, array.AsSpan(2, 2))); } + #endregion + #region Sum [Theory] [MemberData(nameof(TensorLengths))] - public static void ProductOfSums(int tensorLength) + public static void Sum(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = 
CreateAndFillTensor(tensorLength); - float f = x[0] + y[0]; - for (int i = 1; i < x.Length; i++) + AssertEqualTolerance(MemoryMarshal.ToEnumerable(x.Memory).Sum(), TensorPrimitives.Sum(x)); + + float sum = 0; + foreach (float f in x.Span) { - f *= x[i] + y[i]; + sum += f; } - Assert.Equal(f, TensorPrimitives.ProductOfSums(x, y), Tolerance); + AssertEqualTolerance(sum, TensorPrimitives.Sum(x)); } - [Fact] - public static void ProductOfSums_KnownValues() + [Theory] + [InlineData(0, new float[] { 0 })] + [InlineData(1, new float[] { 0, 1 })] + [InlineData(6, new float[] { 1, 2, 3 })] + [InlineData(0, new float[] { -3, 0, 3 })] + [InlineData(float.NaN, new float[] { -3, float.NaN, 3 })] + public static void Sum_KnownValues(float expected, float[] x) { - Assert.Equal(0, TensorPrimitives.ProductOfSums([0], [0])); - Assert.Equal(1, TensorPrimitives.ProductOfSums([0], [1])); - Assert.Equal(1, TensorPrimitives.ProductOfSums([1], [0])); - Assert.Equal(2, TensorPrimitives.ProductOfSums([1], [1])); - Assert.Equal(10395, TensorPrimitives.ProductOfSums([1, 2, 3, 4, 5], [2, 3, 4, 5, 6])); - Assert.Equal(120, TensorPrimitives.ProductOfSums([1, 2, 3, 4, 5], [0, 0, 0, 0, 0])); - Assert.Equal(120, TensorPrimitives.ProductOfSums([0, 0, 0, 0, 0], [1, 2, 3, 4, 5])); - Assert.Equal(float.NaN, TensorPrimitives.ProductOfSums([1, 2, float.NaN, 4, 5], [0, 0, 0, 0, 0])); + Assert.Equal(expected, TensorPrimitives.Sum(x)); } + #endregion + #region SumOfMagnitudes [Theory] [MemberData(nameof(TensorLengths))] - public static void Sum(int tensorLength) + public static void SumOfMagnitudes(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - Assert.Equal(Enumerable.Sum(MemoryMarshal.ToEnumerable(x.Memory)), TensorPrimitives.Sum(x), Tolerance); + AssertEqualTolerance(Enumerable.Sum(MemoryMarshal.ToEnumerable(x.Memory), MathF.Abs), TensorPrimitives.SumOfMagnitudes(x)); float sum = 0; foreach (float f in x.Span) { - sum += f; + sum += MathF.Abs(f); } - Assert.Equal(sum, TensorPrimitives.Sum(x), Tolerance); + AssertEqualTolerance(sum, TensorPrimitives.SumOfMagnitudes(x)); } - [Fact] - public static void Sum_KnownValues() + [Theory] + [InlineData(0, new float[] { 0 })] + [InlineData(1, new float[] { 0, 1 })] + [InlineData(6, new float[] { 1, 2, 3 })] + [InlineData(6, new float[] { -3, 0, 3 })] + [InlineData(float.NaN, new float[] { -3, float.NaN, 3 })] + public static void SumOfMagnitudes_KnownValues(float expected, float[] x) { - Assert.Equal(0, TensorPrimitives.Sum([0])); - Assert.Equal(1, TensorPrimitives.Sum([0, 1])); - Assert.Equal(6, TensorPrimitives.Sum([1, 2, 3])); - Assert.Equal(0, TensorPrimitives.Sum([-3, 0, 3])); - Assert.Equal(float.NaN, TensorPrimitives.Sum([-3, float.NaN, 3])); + Assert.Equal(expected, TensorPrimitives.SumOfMagnitudes(x)); } + #endregion + #region SumOfSquares [Theory] [MemberData(nameof(TensorLengths))] public static void SumOfSquares(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - Assert.Equal(Enumerable.Sum(MemoryMarshal.ToEnumerable(x.Memory), v => v * v), TensorPrimitives.SumOfSquares(x), Tolerance); + AssertEqualTolerance(Enumerable.Sum(MemoryMarshal.ToEnumerable(x.Memory), v => v * v), TensorPrimitives.SumOfSquares(x)); float sum = 0; foreach (float f in x.Span) { sum += f * f; } - Assert.Equal(sum, TensorPrimitives.SumOfSquares(x), Tolerance); + AssertEqualTolerance(sum, TensorPrimitives.SumOfSquares(x)); } - [Fact] - public static void SumOfSquares_KnownValues() + [Theory] + [InlineData(0, new float[] { 0 })] + [InlineData(1, 
new float[] { 0, 1 })] + [InlineData(14, new float[] { 1, 2, 3 })] + [InlineData(18, new float[] { -3, 0, 3 })] + [InlineData(float.NaN, new float[] { -3, float.NaN, 3 })] + public static void SumOfSquares_KnownValues(float expected, float[] x) { - Assert.Equal(0, TensorPrimitives.SumOfSquares([0])); - Assert.Equal(1, TensorPrimitives.SumOfSquares([0, 1])); - Assert.Equal(14, TensorPrimitives.SumOfSquares([1, 2, 3])); - Assert.Equal(18, TensorPrimitives.SumOfSquares([-3, 0, 3])); - Assert.Equal(float.NaN, TensorPrimitives.SumOfSquares([-3, float.NaN, 3])); + Assert.Equal(expected, TensorPrimitives.SumOfSquares(x)); } + #endregion + #region Tanh [Theory] - [MemberData(nameof(TensorLengths))] - public static void SumOfMagnitudes(int tensorLength) + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Tanh(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); - Assert.Equal(Enumerable.Sum(MemoryMarshal.ToEnumerable(x.Memory), MathF.Abs), TensorPrimitives.SumOfMagnitudes(x), Tolerance); + TensorPrimitives.Tanh(x, destination); - float sum = 0; - foreach (float f in x.Span) + for (int i = 0; i < tensorLength; i++) { - sum += MathF.Abs(f); + AssertEqualTolerance(MathF.Tanh(x[i]), destination[i]); } - Assert.Equal(sum, TensorPrimitives.SumOfMagnitudes(x), Tolerance); } - [Fact] - public static void SumOfMagnitudes_KnownValues() + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Tanh_InPlace(int tensorLength) { - Assert.Equal(0, TensorPrimitives.SumOfMagnitudes([0])); - Assert.Equal(1, TensorPrimitives.SumOfMagnitudes([0, 1])); - Assert.Equal(6, TensorPrimitives.SumOfMagnitudes([1, 2, 3])); - Assert.Equal(6, TensorPrimitives.SumOfMagnitudes([-3, 0, 3])); - Assert.Equal(float.NaN, TensorPrimitives.SumOfMagnitudes([-3, float.NaN, 3])); + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.Tanh(x, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Tanh(xOrig[i]), x[i]); + } } [Theory] - [MemberData(nameof(TensorLengthsIncluding0))] - public static void Abs(int tensorLength) + [MemberData(nameof(TensorLengths))] + public static void Tanh_SpecialValues(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); using BoundedMemory destination = CreateTensor(tensorLength); - TensorPrimitives.Abs(x, destination); + RunForEachSpecialValue(() => + { + TensorPrimitives.Tanh(x, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Tanh(x[i]), destination[i]); + } + }, x); + } - for (int i = 0; i < x.Length; i++) + [Theory] + [MemberData(nameof(VectorLengthAndIteratedRange), new object[] { -11f, 11f, 0.2f })] + public static void Tanh_ValueRange(int vectorLengths, float element) + { + float[] x = new float[vectorLengths]; + float[] dest = new float[vectorLengths]; + + x.AsSpan().Fill(element); + TensorPrimitives.Tanh(x, dest); + + float expected = MathF.Tanh(element); + foreach (float actual in dest) { - Assert.Equal(MathF.Abs(x[i]), destination[i], Tolerance); + AssertEqualTolerance(expected, actual); } } [Theory] [MemberData(nameof(TensorLengths))] - public static void Abs_ThrowsForTooShortDestination(int tensorLength) + public static void Tanh_ThrowsForTooShortDestination(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); using BoundedMemory destination = CreateTensor(tensorLength - 1); - 
AssertExtensions.Throws("destination", () => TensorPrimitives.Abs(x, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.Tanh(x, destination)); + } + + [Fact] + public static void Tanh_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Tanh(array.AsSpan(1, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Tanh(array.AsSpan(1, 2), array.AsSpan(2, 2))); } + #endregion } } diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.netcore.cs b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.netcore.cs index 113f26048d352..06ab341db1624 100644 --- a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.netcore.cs +++ b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.netcore.cs @@ -8,15 +8,16 @@ namespace System.Numerics.Tensors.Tests { public static partial class TensorPrimitivesTests { + #region ConvertToHalf [Theory] - [InlineData(0)] - [MemberData(nameof(TensorLengths))] + [MemberData(nameof(TensorLengthsIncluding0))] public static void ConvertToHalf(int tensorLength) { using BoundedMemory source = CreateAndFillTensor(tensorLength); foreach (int destLength in new[] { source.Length, source.Length + 1 }) { - Half[] destination = new Half[destLength]; + using BoundedMemory destination = BoundedMemory.Allocate(destLength); + destination.Span.Fill(Half.Zero); TensorPrimitives.ConvertToHalf(source, destination); @@ -35,6 +36,28 @@ public static void ConvertToHalf(int tensorLength) } } + [Theory] + [MemberData(nameof(TensorLengths))] + public static void ConvertToHalf_SpecialValues(int tensorLength) + { + using BoundedMemory source = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = BoundedMemory.Allocate(tensorLength); + + // NaN, infinities, and 0s + source[s_random.Next(source.Length)] = float.NaN; + source[s_random.Next(source.Length)] = float.PositiveInfinity; + source[s_random.Next(source.Length)] = float.NegativeInfinity; + source[s_random.Next(source.Length)] = 0; + source[s_random.Next(source.Length)] = float.NegativeZero; + + TensorPrimitives.ConvertToHalf(source, destination); + + for (int i = 0; i < source.Length; i++) + { + Assert.Equal((Half)source[i], destination[i]); + } + } + [Theory] [MemberData(nameof(TensorLengths))] public static void ConvertToHalf_ThrowsForTooShortDestination(int tensorLength) @@ -44,13 +67,14 @@ public static void ConvertToHalf_ThrowsForTooShortDestination(int tensorLength) AssertExtensions.Throws("destination", () => TensorPrimitives.ConvertToHalf(source, destination)); } + #endregion + #region ConvertToSingle [Theory] - [InlineData(0)] - [MemberData(nameof(TensorLengths))] + [MemberData(nameof(TensorLengthsIncluding0))] public static void ConvertToSingle(int tensorLength) { - Half[] source = new Half[tensorLength]; + using BoundedMemory source = BoundedMemory.Allocate(tensorLength); for (int i = 0; i < source.Length; i++) { source[i] = (Half)s_random.NextSingle(); @@ -77,6 +101,32 @@ public static void ConvertToSingle(int tensorLength) } } } + [Theory] + [MemberData(nameof(TensorLengths))] + public static void ConvertToSingle_SpecialValues(int tensorLength) + { + using BoundedMemory source = BoundedMemory.Allocate(tensorLength); + for (int i = 0; i < source.Length; i++) + { + source[i] = (Half)s_random.NextSingle(); + } + + using BoundedMemory destination = CreateTensor(tensorLength); + + // NaN, infinities, and 
0s + source[s_random.Next(source.Length)] = Half.NaN; + source[s_random.Next(source.Length)] = Half.PositiveInfinity; + source[s_random.Next(source.Length)] = Half.NegativeInfinity; + source[s_random.Next(source.Length)] = Half.Zero; + source[s_random.Next(source.Length)] = Half.NegativeZero; + + TensorPrimitives.ConvertToSingle(source, destination); + + for (int i = 0; i < source.Length; i++) + { + Assert.Equal((float)source[i], destination[i]); + } + } [Theory] [MemberData(nameof(TensorLengths))] @@ -87,5 +137,6 @@ public static void ConvertToSingle_ThrowsForTooShortDestination(int tensorLength AssertExtensions.Throws("destination", () => TensorPrimitives.ConvertToSingle(source, destination)); } + #endregion } } diff --git a/src/libraries/System.Private.CoreLib/src/System/Half.cs b/src/libraries/System.Private.CoreLib/src/System/Half.cs index 8daa37bbab576..cd3e6ab3ed73c 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Half.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Half.cs @@ -1044,7 +1044,7 @@ public static explicit operator float(Half value) // BitConverter.SingleToUInt32Bits(1.0f) - ((uint)BitConverter.HalfToUInt16Bits((Half)1.0f) << 13) const uint ExponentOffset = 0x3800_0000u; // Mask for sign bit in Single - const uint FloatSignMask = float.SignMask; + const uint SingleSignMask = float.SignMask; // Mask for exponent bits in Half const uint HalfExponentMask = BiasedExponentMask; // Mask for bits in Single converted from Half @@ -1052,7 +1052,7 @@ public static explicit operator float(Half value) // Extract the internal representation of value short valueInInt16Bits = BitConverter.HalfToInt16Bits(value); // Extract sign bit of value - uint sign = (uint)(int)valueInInt16Bits & FloatSignMask; + uint sign = (uint)(int)valueInInt16Bits & SingleSignMask; // Copy sign bit to upper bits uint bitValueInProcess = (uint)valueInInt16Bits; // Extract exponent bits of value (BiasedExponent is not for here as it performs unnecessary shift)
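As a standalone illustration of why the hunk above masks the sign-extended Half bits with the Single sign mask (the constant renamed to SingleSignMask): casting the raw 16 bits through (int) sign-extends them, so the Half's sign bit lands in bit 31, exactly where Single stores its sign. This is a minimal sketch, not the BCL implementation; the class name is hypothetical, the 0x8000_0000 literal is assumed to match float.SignMask, and it assumes .NET 5+ for BitConverter.HalfToInt16Bits.

using System;

class HalfSignBitSketch
{
    static void Main()
    {
        // Raw IEEE 754 binary16 bits of a negative value; bit 15 is the Half sign bit.
        short halfBits = BitConverter.HalfToInt16Bits((Half)(-1.5f));

        // (int) sign-extends, copying the Half sign bit up through bit 31;
        // masking with the Single sign-bit position isolates it.
        uint sign = (uint)(int)halfBits & 0x8000_0000u;

        Console.WriteLine(sign == 0x8000_0000u); // True: the Half was negative
    }
}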